You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_angle.py 3.4 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import numpy as np
  16. import pytest
  17. import mindspore.dataset as ds
  18. import mindspore.dataset.audio.transforms as a_c_trans
  19. def _count_unequal_element(data_expected, data_me, rtol, atol):
  20. assert data_expected.shape == data_me.shape
  21. total_count = len(data_expected.flatten())
  22. error = np.abs(data_expected - data_me)
  23. greater = np.greater(error, atol + np.abs(data_expected) * rtol)
  24. loss_count = np.count_nonzero(greater)
  25. assert (loss_count / total_count) < rtol, \
  26. "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
  27. format(data_expected[greater], data_me[greater], error[greater])
  28. def test_func_angle_001():
  29. """
  30. Eager Test
  31. """
  32. arr = np.array([[73.04, -13.00], [57.49, 13.20], [-57.64, 6.51], [-52.25, 30.67], [-30.11, -18.34], \
  33. [-63.32, 99.33], [95.82, -24.76]], dtype=np.double)
  34. expected = np.array([-0.17614017, 0.22569334, 3.02912684, 2.6107975, -2.59450886, 2.13831337, -0.25286988], \
  35. dtype=np.double)
  36. angle_op = a_c_trans.Angle()
  37. output = angle_op(arr)
  38. _count_unequal_element(expected, output, 0.0001, 0.0001)
  39. def test_func_angle_002():
  40. """
  41. Pipeline Test
  42. """
  43. np.random.seed(6)
  44. arr = np.array([[[84.25, -85.92], [-92.23, 23.06], [-7.33, -44.17], [-62.95, -14.73]], \
  45. [[93.09, 38.18], [-81.94, 71.34], [71.33, -39.00], [95.25, -32.94]]], dtype=np.double)
  46. expected = np.array([[-0.79521156, 2.89658848, -1.73524737, -2.91173309], \
  47. [0.3892177, 2.42523905, -0.50034807, -0.33295219]], dtype=np.double)
  48. label = np.random.sample((2, 4, 1))
  49. data = (arr, label)
  50. dataset = ds.NumpySlicesDataset(data, column_names=["col1", "col2"], shuffle=False)
  51. angle_op = a_c_trans.Angle()
  52. dataset = dataset.map(operations=angle_op, input_columns=["col1"])
  53. for item1, item2 in zip(dataset.create_dict_iterator(output_numpy=True), expected):
  54. _count_unequal_element(item2, item1['col1'], 0.0001, 0.0001)
  55. def test_func_angle_003():
  56. """
  57. Pipeline Error Test
  58. """
  59. np.random.seed(78)
  60. arr = np.array([["11", "22"], ["33", "44"], ["55", "66"], ["77", "88"]])
  61. label = np.random.sample((4, 1))
  62. data = (arr, label)
  63. dataset = ds.NumpySlicesDataset(data, column_names=["col1", 'col2'], shuffle=False)
  64. angle_op = a_c_trans.Angle()
  65. dataset = dataset.map(operations=angle_op, input_columns=["col1"])
  66. num_itr = 0
  67. with pytest.raises(RuntimeError, match="The input type should be numbers"):
  68. for _ in dataset.create_dict_iterator(output_numpy=True):
  69. num_itr += 1
  70. if __name__ == "__main__":
  71. test_func_angle_001()
  72. test_func_angle_002()
  73. test_func_angle_003()