You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_time_masking.py 6.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Testing TimeMasking op in DE.
  17. """
  18. import numpy as np
  19. import pytest
  20. import mindspore.dataset as ds
  21. import mindspore.dataset.audio.transforms as audio
  22. from mindspore import log as logger
  # Sizes used as the (CHANNEL, FREQ, TIME) shape of the synthetic
  # spectrograms fed to TimeMasking in the tests of this module.
  CHANNEL, FREQ, TIME = 2, 20, 30
  def gen(shape):
      """Yield one reproducible random float32 array of the given shape.

      The sample is wrapped in a 1-tuple so the generator can feed a
      single-column ``GeneratorDataset`` directly.

      Args:
          shape: Output shape (tuple or list of ints).

      Yields:
          tuple: ``(array,)`` where ``array`` holds values drawn uniformly
          from [0, 1) with seed 0 and cast to ``np.float32``.
      """
      # A private RandomState seeded with 0 produces exactly the same values as
      # ``np.random.seed(0); np.random.random(shape)`` but does not clobber the
      # global NumPy RNG state seen by other code running in the same process.
      rng = np.random.RandomState(0)
      data = rng.random_sample(shape)
      yield (data.astype(np.float32),)
  def count_unequal_element(data_expected, data_me, rtol, atol):
      """Assert that only a small fraction of elements fall outside tolerance.

      An element counts as mismatched when
      ``|expected - actual| > atol + |expected| * rtol``. The fraction of
      mismatched elements must be strictly below ``rtol``; otherwise an
      ``AssertionError`` is raised listing the offending expected values,
      actual values and absolute errors.

      Args:
          data_expected: Reference array.
          data_me: Array under test; must have the same shape.
          rtol: Relative tolerance, also reused as the maximum allowed
              fraction of mismatched elements.
          atol: Absolute tolerance.

      Raises:
          AssertionError: On a shape mismatch or too many mismatched elements.
      """
      assert data_expected.shape == data_me.shape
      # ``.size`` counts elements without the copy ``flatten()`` makes.
      total_count = data_expected.size
      if total_count == 0:
          # Nothing to compare; also avoids a ZeroDivisionError below.
          return
      error = np.abs(data_expected - data_me)
      greater = np.greater(error, atol + np.abs(data_expected) * rtol)
      loss_count = np.count_nonzero(greater)
      assert (loss_count / total_count) < rtol, "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".format(
          data_expected[greater], data_me[greater], error[greater])
  def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
      """Check that ``data_me`` matches ``data_expected`` within tolerance.

      If the reference contains any NaN, a plain ``np.allclose`` assertion is
      made, with NaN handling controlled by ``equal_nan``. Otherwise, when the
      arrays are not all-close, the decision is delegated to
      ``count_unequal_element``, which tolerates a small fraction of
      mismatched elements.
      """
      expected_has_nan = np.any(np.isnan(data_expected))
      if expected_has_nan:
          assert np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan)
          return
      if np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan):
          return
      count_unequal_element(data_expected, data_me, rtol, atol)
  def test_func_time_masking_eager_random_input():
      """Eager-mode TimeMasking on random input keeps the spectrogram shape.

      Calls ``audio.TimeMasking(False, 3, 1, 10)`` directly on a random
      float32 array. The first three positional arguments are iid_masks,
      time_mask_param and mask_start; the fourth is presumably the mask
      value (confirm against the TimeMasking signature).
      """
      logger.info("test time_masking op")
      expected_shape = (CHANNEL, FREQ, TIME)
      spectrogram = next(gen(expected_shape))[0]
      masking_op = audio.TimeMasking(False, 3, 1, 10)
      masked = masking_op(spectrogram)
      assert masked.shape == expected_shape
  def test_func_time_masking_eager_precision():
      """Eager-mode TimeMasking zeroes a fixed leading span of the last axis.

      With ``audio.TimeMasking(False, 2, 0, 0)`` the benchmark expects the
      first two entries along the last axis to become 0 in every channel and
      row, with all other values unchanged.
      """
      logger.info("test time_masking op")
      # Input shape is (2, 3, 4).
      spectrogram = np.array([[[0.17274511, 0.85174704, 0.07162686, -0.45436913],
                               [-1.045921, -1.8204843, 0.62333095, -0.09532598],
                               [1.8175547, -0.25779432, -0.58152324, -0.00221091]],
                              [[-1.205032, 0.18922766, -0.5277673, -1.3090396],
                               [1.8914849, -0.97001046, -0.23726775, 0.00525892],
                               [-1.0271876, 0.33526883, 1.7413973, 0.12313101]]]).astype(np.float32)
      out_ms = audio.TimeMasking(False, 2, 0, 0)(spectrogram)
      out_benchmark = np.array([[[0., 0., 0.07162686, -0.45436913],
                                 [0., 0., 0.62333095, -0.09532598],
                                 [0., 0., -0.58152324, -0.00221091]],
                                [[0., 0., -0.5277673, -1.3090396],
                                 [0., 0., -0.23726775, 0.00525892],
                                 [0., 0., 1.7413973, 0.12313101]]]).astype(np.float32)
      # allclose_nparray takes (expected, actual): the benchmark is the
      # reference, so it goes first; the op output is the value under test.
      allclose_nparray(out_benchmark, out_ms, 0.0001, 0.0001)
  def test_func_time_masking_pipeline():
      """Pipeline-mode TimeMasking keeps the spectrogram shape.

      Maps ``audio.TimeMasking(True, 8)`` over a GeneratorDataset built from
      ``gen`` and checks the shape of every produced row. It also asserts the
      row count, so the shape check cannot pass vacuously on an empty
      pipeline.
      """
      logger.info("test time_masking op, pipeline")
      generator = gen([CHANNEL, FREQ, TIME])
      data1 = ds.GeneratorDataset(source=generator, column_names=["multi_dimensional_data"])
      transforms = [audio.TimeMasking(True, 8)]
      data1 = data1.map(operations=transforms, input_columns=["multi_dimensional_data"])
      num_rows = 0
      for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
          out_put = item["multi_dimensional_data"]
          assert out_put.shape == (CHANNEL, FREQ, TIME)
          num_rows += 1
      # gen() yields exactly one sample, so exactly one row must come out.
      assert num_rows == 1
  def test_time_masking_invalid_input():
      """TimeMasking rejects invalid parameters and inputs with clear errors.

      ``test_invalid_param`` checks errors raised while constructing the op.
      ``test_invalid_input`` checks errors raised when the op is constructed
      and applied to a (CHANNEL, FREQ, TIME) random spectrogram.
      """
      def test_invalid_param(test_name, iid_masks, time_mask_param, mask_start, error, error_msg):
          """Assert constructing TimeMasking raises ``error`` containing ``error_msg``."""
          logger.info("Test TimeMasking with wrong params: {0}".format(test_name))
          with pytest.raises(error) as error_info:
              audio.TimeMasking(iid_masks, time_mask_param, mask_start)
          assert error_msg in str(error_info.value)

      def test_invalid_input(test_name, iid_masks, time_mask_param, mask_start, error, error_msg):
          """Assert applying TimeMasking raises ``error`` containing ``error_msg``."""
          logger.info("Test TimeMasking with wrong params: {0}".format(test_name))
          # Build the input outside pytest.raises so only the op itself is
          # expected to raise; a failure in data generation must not be
          # mistaken for the error under test.
          spectrogram = next(gen((CHANNEL, FREQ, TIME)))[0]
          with pytest.raises(error) as error_info:
              _ = audio.TimeMasking(iid_masks, time_mask_param, mask_start)(spectrogram)
          assert error_msg in str(error_info.value)

      test_invalid_param("invalid mask_start", True, 2, -10, ValueError,
                         "Input mask_start is not within the required interval of [0, 16777216].")
      test_invalid_param("invalid mask_param", True, -2, 10, ValueError,
                         "Input mask_param is not within the required interval of [0, 16777216].")
      test_invalid_param("invalid iid_masks", "True", 2, 10, TypeError,
                         "Argument iid_masks with value True is not of type [<class 'bool'>], but got <class 'str'>.")
      test_invalid_input("invalid mask_start", False, 2, 100, RuntimeError,
                         "MaskAlongAxis: mask_start should be less than the length of chosen dimension.")
      test_invalid_input("invalid mask_width", False, 200, 2, RuntimeError,
                         "TimeMasking: time_mask_param should be less than or equal to the length of time dimension.")
  if __name__ == "__main__":
      # Run every test of this module, in definition order, when executed
      # as a script instead of through pytest.
      for test_case in (
              test_func_time_masking_eager_random_input,
              test_func_time_masking_eager_precision,
              test_func_time_masking_pipeline,
              test_time_masking_invalid_input,
      ):
          test_case()