You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_dc_shift.py 3.2 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import numpy as np
  16. import pytest
  17. import mindspore.dataset as ds
  18. import mindspore.dataset.audio.transforms as a_c_trans
  19. def count_unequal_element(data_expected, data_me, rtol, atol):
  20. assert data_expected.shape == data_me.shape
  21. total_count = len(data_expected.flatten())
  22. error = np.abs(data_expected - data_me)
  23. greater = np.greater(error, atol + np.abs(data_expected) * rtol)
  24. loss_count = np.count_nonzero(greater)
  25. assert (loss_count / total_count) < rtol, "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".format(
  26. data_expected[greater], data_me[greater], error[greater])
  27. def test_func_dc_shift_eager():
  28. """
  29. Eager Test
  30. """
  31. arr = np.array([0.60, 0.97, -1.04, -1.26, 0.97, 0.91, 0.48, 0.93, 0.71, 0.61], dtype=np.double)
  32. expected = np.array([0.0400, 0.0400, -0.0400, -0.2600, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400],
  33. dtype=np.double)
  34. dcshift_op = a_c_trans.DCShift(1.0, 0.04)
  35. output = dcshift_op(arr)
  36. count_unequal_element(expected, output, 0.0001, 0.0001)
  37. def test_func_dc_shift_pipeline():
  38. """
  39. Pipeline Test
  40. """
  41. arr = np.array([[1.14, -1.06, 0.94, 0.90], [-1.11, 1.40, -0.33, 1.43]], dtype=np.double)
  42. expected = np.array([[0.2300, -0.2600, 0.2300, 0.2300], [-0.3100, 0.2300, 0.4700, 0.2300]], dtype=np.double)
  43. dataset = ds.NumpySlicesDataset(arr, column_names=["col1"], shuffle=False)
  44. dcshift_op = a_c_trans.DCShift(0.8, 0.03)
  45. dataset = dataset.map(operations=dcshift_op, input_columns=["col1"])
  46. for item1, item2 in zip(dataset.create_dict_iterator(num_epochs=1, output_numpy=True), expected):
  47. count_unequal_element(item2, item1['col1'], 0.0001, 0.0001)
  48. def test_func_dc_shift_pipeline_error():
  49. """
  50. Pipeline Error Test
  51. """
  52. arr = np.random.uniform(-2, 2, size=(1000)).astype(np.float)
  53. label = np.random.sample((1000, 1))
  54. data = (arr, label)
  55. dataset = ds.NumpySlicesDataset(data, column_names=["col1", "col2"], shuffle=False)
  56. num_itr = 0
  57. with pytest.raises(ValueError, match=r"Input shift is not within the required interval of \[-2.0, 2.0\]."):
  58. dcshift_op = a_c_trans.DCShift(2.5, 0.03)
  59. dataset = dataset.map(operations=dcshift_op, input_columns=["col1"])
  60. for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
  61. num_itr += 1
  62. if __name__ == "__main__":
  63. test_func_dc_shift_eager()
  64. test_func_dc_shift_pipeline()
  65. test_func_dc_shift_pipeline_error()