You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_amplitude_to_db.py 6.2 kB

(Lines 1–137)
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Testing AmplitudeToDB op in DE
  17. """
  18. import numpy as np
  19. import pytest
  20. import mindspore.dataset as ds
  21. import mindspore.dataset.audio.transforms as c_audio
  22. from mindspore import log as logger
  23. from mindspore.dataset.audio.utils import ScaleType
  24. CHANNEL = 1
  25. FREQ = 20
  26. TIME = 15
  27. def gen(shape):
  28. np.random.seed(0)
  29. data = np.random.random(shape)
  30. yield(np.array(data, dtype=np.float32),)
  31. def _count_unequal_element(data_expected, data_me, rtol, atol):
  32. """ Precision calculation func """
  33. assert data_expected.shape == data_me.shape
  34. total_count = len(data_expected.flatten())
  35. error = np.abs(data_expected - data_me)
  36. greater = np.greater(error, atol + np.abs(data_expected) * rtol)
  37. loss_count = np.count_nonzero(greater)
  38. assert (loss_count / total_count) < rtol, \
  39. "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
  40. format(data_expected[greater], data_me[greater], error[greater])
  41. def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
  42. """ Precision calculation formula """
  43. if np.any(np.isnan(data_expected)):
  44. assert np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan)
  45. elif not np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan):
  46. _count_unequal_element(data_expected, data_me, rtol, atol)
  47. else:
  48. assert True
  49. def test_func_amplitude_to_db_eager():
  50. """ mindspore eager mode normal testcase:amplitude_to_db op"""
  51. logger.info("check amplitude_to_db op output")
  52. ndarr_in = np.array([[[[-0.2197528, 0.3821656]]],
  53. [[[0.57418776, 0.46741104]]],
  54. [[[-0.20381108, -0.9303914]]],
  55. [[[0.3693608, -0.2017813]]],
  56. [[[-1.727381, -1.3708513]]],
  57. [[[1.259975, 0.4981323]]],
  58. [[[0.76986176, -0.5793846]]]]).astype(np.float32)
  59. # cal from benchmark
  60. out_expect = np.array([[[[-84.17748, -4.177484]]],
  61. [[[-2.4094608, -3.3030105]]],
  62. [[[-100., -100.]]],
  63. [[[-4.325492, -84.32549]]],
  64. [[[-100., -100.]]],
  65. [[[1.0036192, -3.0265532]]],
  66. [[[-1.1358725, -81.13587]]]]).astype(np.float32)
  67. amplitude_to_db_op = c_audio.AmplitudeToDB()
  68. out_mindspore = amplitude_to_db_op(ndarr_in)
  69. allclose_nparray(out_mindspore, out_expect, 0.0001, 0.0001)
  70. def test_func_amplitude_to_db_pipeline():
  71. """ mindspore pipeline mode normal testcase:amplitude_to_db op"""
  72. logger.info("test AmplitudeToDB op with default value")
  73. generator = gen([CHANNEL, FREQ, TIME])
  74. data1 = ds.GeneratorDataset(source=generator, column_names=["multi_dimensional_data"])
  75. transforms = [
  76. c_audio.AmplitudeToDB()
  77. ]
  78. data1 = data1.map(operations=transforms, input_columns=["multi_dimensional_data"])
  79. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  80. out_put = item["multi_dimensional_data"]
  81. assert out_put.shape == (CHANNEL, FREQ, TIME)
  82. def test_amplitude_to_db_invalid_input():
  83. def test_invalid_input(test_name, stype, ref_value, amin, top_db, error, error_msg):
  84. logger.info("Test AmplitudeToDB with bad input: {0}".format(test_name))
  85. with pytest.raises(error) as error_info:
  86. c_audio.AmplitudeToDB(stype=stype, ref_value=ref_value, amin=amin, top_db=top_db)
  87. assert error_msg in str(error_info.value)
  88. test_invalid_input("invalid stype parameter value", "test", 1.0, 1e-10, 80.0, TypeError,
  89. "Argument stype with value test is not of type [<enum 'ScaleType'>], but got <class 'str'>.")
  90. test_invalid_input("invalid ref_value parameter value", ScaleType.POWER, -1.0, 1e-10, 80.0, ValueError,
  91. "Input ref_value is not within the required interval of (0, 16777216]")
  92. test_invalid_input("invalid amin parameter value", ScaleType.POWER, 1.0, -1e-10, 80.0, ValueError,
  93. "Input amin is not within the required interval of (0, 16777216]")
  94. test_invalid_input("invalid top_db parameter value", ScaleType.POWER, 1.0, 1e-10, -80.0, ValueError,
  95. "Input top_db is not within the required interval of (0, 16777216]")
  96. test_invalid_input("invalid stype parameter value", True, 1.0, 1e-10, 80.0, TypeError,
  97. "Argument stype with value True is not of type [<enum 'ScaleType'>], but got <class 'bool'>.")
  98. test_invalid_input("invalid ref_value parameter value", ScaleType.POWER, "value", 1e-10, 80.0, TypeError,
  99. "Argument ref_value with value value is not of type [<class 'int'>, <class 'float'>], " +
  100. "but got <class 'str'>")
  101. test_invalid_input("invalid amin parameter value", ScaleType.POWER, 1.0, "value", -80.0, TypeError,
  102. "Argument amin with value value is not of type [<class 'int'>, <class 'float'>], " +
  103. "but got <class 'str'>")
  104. test_invalid_input("invalid top_db parameter value", ScaleType.POWER, 1.0, 1e-10, "value", TypeError,
  105. "Argument top_db with value value is not of type [<class 'int'>, <class 'float'>], " +
  106. "but got <class 'str'>")
  107. if __name__ == "__main__":
  108. test_func_amplitude_to_db_eager()
  109. test_func_amplitude_to_db_pipeline()
  110. test_amplitude_to_db_invalid_input()