You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_mask_along_axis.py 7.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. # Copyright 2022 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import copy
  16. import numpy as np
  17. import pytest
  18. import mindspore.dataset as ds
  19. import mindspore.dataset.audio.transforms as atf
  20. from mindspore import log as logger
  21. CHANNEL = 1
  22. FREQ = 5
  23. TIME = 5
  24. def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
  25. """
  26. Precision calculation formula
  27. """
  28. if np.any(np.isnan(data_expected)):
  29. assert np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan)
  30. elif not np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan):
  31. count_unequal_element(data_expected, data_me, rtol, atol)
  32. def count_unequal_element(data_expected, data_me, rtol, atol):
  33. """
  34. Precision calculation func
  35. """
  36. assert data_expected.shape == data_me.shape
  37. total_count = len(data_expected.flatten())
  38. error = np.abs(data_expected - data_me)
  39. greater = np.greater(error, atol + np.abs(data_expected) * rtol)
  40. loss_count = np.count_nonzero(greater)
  41. assert (loss_count / total_count) < rtol, "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".format(
  42. data_expected[greater], data_me[greater], error[greater])
  43. def gen(shape):
  44. np.random.seed(0)
  45. data = np.random.random(shape)
  46. yield(np.array(data, dtype=np.float32),)
  47. def test_mask_along_axis_eager_random_input():
  48. """
  49. Feature: MaskAlongAxis
  50. Description: mindspore eager mode normal testcase with random input tensor
  51. Expectation: the returned result is as expected
  52. """
  53. logger.info("test Mask_Along_axis op")
  54. spectrogram = next(gen((CHANNEL, FREQ, TIME)))[0]
  55. expect_output = copy.deepcopy(spectrogram)
  56. out_put = atf.MaskAlongAxis(mask_start=0, mask_width=1, mask_value=5.0, axis=2)(spectrogram)
  57. for item in expect_output[0]:
  58. item[0] = 5.0
  59. assert out_put.shape == (CHANNEL, FREQ, TIME)
  60. allclose_nparray(out_put, expect_output, 0.0001, 0.0001)
  61. def test_mask_along_axis_eager_precision():
  62. """
  63. Feature: MaskAlongAxis
  64. Description: mindspore eager mode checking precision
  65. Expectation: the returned result is as expected
  66. """
  67. logger.info("test MaskAlongAxis op, checking precision")
  68. spectrogram_0 = np.array([[[-0.0635, -0.6903],
  69. [-1.7175, -0.0815],
  70. [0.7981, -0.8297],
  71. [-0.4589, -0.7506]],
  72. [[0.6189, 1.1874],
  73. [0.1856, -0.5536],
  74. [1.0620, 0.2071],
  75. [-0.3874, 0.0664]]]).astype(np.float32)
  76. out_ms_0 = atf.MaskAlongAxis(mask_start=0, mask_width=1, mask_value=2.0, axis=2)(spectrogram_0)
  77. spectrogram_1 = np.array([[[-0.0635, -0.6903],
  78. [-1.7175, -0.0815],
  79. [0.7981, -0.8297],
  80. [-0.4589, -0.7506]],
  81. [[0.6189, 1.1874],
  82. [0.1856, -0.5536],
  83. [1.0620, 0.2071],
  84. [-0.3874, 0.0664]]]).astype(np.float64)
  85. out_ms_1 = atf.MaskAlongAxis(mask_start=0, mask_width=1, mask_value=2.0, axis=2)(spectrogram_1)
  86. out_benchmark = np.array([[[2.0000, -0.6903],
  87. [2.0000, -0.0815],
  88. [2.0000, -0.8297],
  89. [2.0000, -0.7506]],
  90. [[2.0000, 1.1874],
  91. [2.0000, -0.5536],
  92. [2.0000, 0.2071],
  93. [2.0000, 0.0664]]]).astype(np.float32)
  94. allclose_nparray(out_ms_0, out_benchmark, 0.0001, 0.0001)
  95. allclose_nparray(out_ms_1, out_benchmark, 0.0001, 0.0001)
  96. def test_mask_along_axis_pipeline():
  97. """
  98. Feature: MaskAlongAxis
  99. Description: mindspore pipeline mode normal testcase
  100. Expectation: the returned result is as expected
  101. """
  102. logger.info("test MaskAlongAxis op, pipeline")
  103. generator = gen((CHANNEL, FREQ, TIME))
  104. expect_output = copy.deepcopy(next(gen((CHANNEL, FREQ, TIME)))[0])
  105. data1 = ds.GeneratorDataset(source=generator, column_names=["multi_dimensional_data"])
  106. transforms = [atf.MaskAlongAxis(mask_start=2, mask_width=2, mask_value=2.0, axis=2)]
  107. data1 = data1.map(operations=transforms, input_columns=["multi_dimensional_data"])
  108. for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  109. out_put = item["multi_dimensional_data"]
  110. for item in expect_output[0]:
  111. item[2] = 2.0
  112. item[3] = 2.0
  113. assert out_put.shape == (CHANNEL, FREQ, TIME)
  114. allclose_nparray(out_put, expect_output, 0.0001, 0.0001)
  115. def test_mask_along_axis_invalid_input():
  116. """
  117. Feature: MaskAlongAxis
  118. Description: mindspore eager mode with invalid input tensor
  119. Expectation: throw correct error and message
  120. """
  121. def test_invalid_param(test_name, mask_start, mask_width, mask_value, axis, error, error_msg):
  122. """
  123. a function used for checking correct error and message with various input
  124. """
  125. logger.info("Test MaskAlongAxis with wrong params: {0}".format(test_name))
  126. with pytest.raises(error) as error_info:
  127. atf.MaskAlongAxis(mask_start, mask_width, mask_value, axis)
  128. assert error_msg in str(error_info.value)
  129. test_invalid_param("invalid mask_start", -1, 10, 1.0, 1, ValueError,
  130. "Input mask_start is not within the required interval of [0, 2147483647].")
  131. test_invalid_param("invalid mask_width", 0, -1, 1.0, 1, ValueError,
  132. "Input mask_width is not within the required interval of [1, 2147483647].")
  133. test_invalid_param("invalid axis", 0, 10, 1.0, 1.0, TypeError,
  134. "Argument axis with value 1.0 is not of type [<class 'int'>], but got <class 'float'>.")
  135. test_invalid_param("invalid axis", 0, 10, 1.0, 0, ValueError,
  136. "Input axis is not within the required interval of [1, 2].")
  137. test_invalid_param("invalid axis", 0, 10, 1.0, 3, ValueError,
  138. "Input axis is not within the required interval of [1, 2].")
  139. test_invalid_param("invalid axis", 0, 10, 1.0, -1, ValueError,
  140. "Input axis is not within the required interval of [1, 2].")
  141. if __name__ == "__main__":
  142. test_mask_along_axis_eager_random_input()
  143. test_mask_along_axis_eager_precision()
  144. test_mask_along_axis_pipeline()
  145. test_mask_along_axis_invalid_input()