You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_amplitude_to_db.py 6.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Testing AmplitudeToDB op in DE
  17. """
  18. import numpy as np
  19. import pytest
  20. import mindspore.dataset as ds
  21. import mindspore.dataset.audio.transforms as c_audio
  22. from mindspore import log as logger
  23. from mindspore.dataset.audio.utils import ScaleType
  24. CHANNEL = 1
  25. FREQ = 20
  26. TIME = 15
  27. def gen(shape):
  28. np.random.seed(0)
  29. data = np.random.random(shape)
  30. yield (np.array(data, dtype=np.float32),)
  31. def count_unequal_element(data_expected, data_me, rtol, atol):
  32. """ Precision calculation func """
  33. assert data_expected.shape == data_me.shape
  34. total_count = len(data_expected.flatten())
  35. error = np.abs(data_expected - data_me)
  36. greater = np.greater(error, atol + np.abs(data_expected) * rtol)
  37. loss_count = np.count_nonzero(greater)
  38. assert (loss_count / total_count) < rtol, "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".format(
  39. data_expected[greater], data_me[greater], error[greater])
  40. def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
  41. """ Precision calculation formula """
  42. if np.any(np.isnan(data_expected)):
  43. assert np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan)
  44. elif not np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan):
  45. count_unequal_element(data_expected, data_me, rtol, atol)
  46. def test_func_amplitude_to_db_eager():
  47. """ mindspore eager mode normal testcase:amplitude_to_db op"""
  48. logger.info("check amplitude_to_db op output")
  49. ndarr_in = np.array([[[[-0.2197528, 0.3821656]]],
  50. [[[0.57418776, 0.46741104]]],
  51. [[[-0.20381108, -0.9303914]]],
  52. [[[0.3693608, -0.2017813]]],
  53. [[[-1.727381, -1.3708513]]],
  54. [[[1.259975, 0.4981323]]],
  55. [[[0.76986176, -0.5793846]]]]).astype(np.float32)
  56. # cal from benchmark
  57. out_expect = np.array([[[[-84.17748, -4.177484]]],
  58. [[[-2.4094608, -3.3030105]]],
  59. [[[-100., -100.]]],
  60. [[[-4.325492, -84.32549]]],
  61. [[[-100., -100.]]],
  62. [[[1.0036192, -3.0265532]]],
  63. [[[-1.1358725, -81.13587]]]]).astype(np.float32)
  64. amplitude_to_db_op = c_audio.AmplitudeToDB()
  65. out_mindspore = amplitude_to_db_op(ndarr_in)
  66. allclose_nparray(out_mindspore, out_expect, 0.0001, 0.0001)
  67. def test_func_amplitude_to_db_pipeline():
  68. """ mindspore pipeline mode normal testcase:amplitude_to_db op"""
  69. logger.info("test AmplitudeToDB op with default value")
  70. generator = gen([CHANNEL, FREQ, TIME])
  71. data1 = ds.GeneratorDataset(source=generator, column_names=["multi_dimensional_data"])
  72. transforms = [c_audio.AmplitudeToDB()]
  73. data1 = data1.map(operations=transforms, input_columns=["multi_dimensional_data"])
  74. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  75. out_put = item["multi_dimensional_data"]
  76. assert out_put.shape == (CHANNEL, FREQ, TIME)
  77. def test_amplitude_to_db_invalid_input():
  78. def test_invalid_input(test_name, stype, ref_value, amin, top_db, error, error_msg):
  79. logger.info("Test AmplitudeToDB with bad input: {0}".format(test_name))
  80. with pytest.raises(error) as error_info:
  81. c_audio.AmplitudeToDB(stype=stype, ref_value=ref_value, amin=amin, top_db=top_db)
  82. assert error_msg in str(error_info.value)
  83. test_invalid_input("invalid stype parameter value", "test", 1.0, 1e-10, 80.0, TypeError,
  84. "Argument stype with value test is not of type [<enum 'ScaleType'>], but got <class 'str'>.")
  85. test_invalid_input("invalid ref_value parameter value", ScaleType.POWER, -1.0, 1e-10, 80.0, ValueError,
  86. "Input ref_value is not within the required interval of (0, 16777216]")
  87. test_invalid_input("invalid amin parameter value", ScaleType.POWER, 1.0, -1e-10, 80.0, ValueError,
  88. "Input amin is not within the required interval of (0, 16777216]")
  89. test_invalid_input("invalid top_db parameter value", ScaleType.POWER, 1.0, 1e-10, -80.0, ValueError,
  90. "Input top_db is not within the required interval of (0, 16777216]")
  91. test_invalid_input("invalid stype parameter value", True, 1.0, 1e-10, 80.0, TypeError,
  92. "Argument stype with value True is not of type [<enum 'ScaleType'>], but got <class 'bool'>.")
  93. test_invalid_input("invalid ref_value parameter value", ScaleType.POWER, "value", 1e-10, 80.0, TypeError,
  94. "Argument ref_value with value value is not of type [<class 'int'>, <class 'float'>], " +
  95. "but got <class 'str'>")
  96. test_invalid_input("invalid amin parameter value", ScaleType.POWER, 1.0, "value", -80.0, TypeError,
  97. "Argument amin with value value is not of type [<class 'int'>, <class 'float'>], " +
  98. "but got <class 'str'>")
  99. test_invalid_input("invalid top_db parameter value", ScaleType.POWER, 1.0, 1e-10, "value", TypeError,
  100. "Argument top_db with value value is not of type [<class 'int'>, <class 'float'>], " +
  101. "but got <class 'str'>")
  102. if __name__ == "__main__":
  103. test_func_amplitude_to_db_eager()
  104. test_func_amplitude_to_db_pipeline()
  105. test_amplitude_to_db_invalid_input()