You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_angle.py 3.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import numpy as np
  16. import pytest
  17. import mindspore.dataset as ds
  18. import mindspore.dataset.audio.transforms as a_c_trans
  19. def count_unequal_element(data_expected, data_me, rtol, atol):
  20. assert data_expected.shape == data_me.shape
  21. total_count = len(data_expected.flatten())
  22. error = np.abs(data_expected - data_me)
  23. greater = np.greater(error, atol + np.abs(data_expected) * rtol)
  24. loss_count = np.count_nonzero(greater)
  25. assert (loss_count / total_count) < rtol, "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".format(
  26. data_expected[greater], data_me[greater], error[greater])
  27. def test_func_angle_001():
  28. """
  29. Eager Test
  30. """
  31. arr = np.array([[73.04, -13.00], [57.49, 13.20], [-57.64, 6.51], [-52.25, 30.67], [-30.11, -18.34],
  32. [-63.32, 99.33], [95.82, -24.76]], dtype=np.double)
  33. expected = np.array([-0.17614017, 0.22569334, 3.02912684, 2.6107975, -2.59450886, 2.13831337, -0.25286988],
  34. dtype=np.double)
  35. angle_op = a_c_trans.Angle()
  36. output = angle_op(arr)
  37. count_unequal_element(expected, output, 0.0001, 0.0001)
  38. def test_func_angle_002():
  39. """
  40. Pipeline Test
  41. """
  42. np.random.seed(6)
  43. arr = np.array([[[84.25, -85.92], [-92.23, 23.06], [-7.33, -44.17], [-62.95, -14.73]],
  44. [[93.09, 38.18], [-81.94, 71.34], [71.33, -39.00], [95.25, -32.94]]], dtype=np.double)
  45. expected = np.array([[-0.79521156, 2.89658848, -1.73524737, -2.91173309],
  46. [0.3892177, 2.42523905, -0.50034807, -0.33295219]], dtype=np.double)
  47. label = np.random.sample((2, 4, 1))
  48. data = (arr, label)
  49. dataset = ds.NumpySlicesDataset(data, column_names=["col1", "col2"], shuffle=False)
  50. angle_op = a_c_trans.Angle()
  51. dataset = dataset.map(operations=angle_op, input_columns=["col1"])
  52. for item1, item2 in zip(dataset.create_dict_iterator(num_epochs=1, output_numpy=True), expected):
  53. count_unequal_element(item2, item1['col1'], 0.0001, 0.0001)
  54. def test_func_angle_003():
  55. """
  56. Pipeline Error Test
  57. """
  58. np.random.seed(78)
  59. arr = np.array([["11", "22"], ["33", "44"], ["55", "66"], ["77", "88"]])
  60. label = np.random.sample((4, 1))
  61. data = (arr, label)
  62. dataset = ds.NumpySlicesDataset(data, column_names=["col1", 'col2'], shuffle=False)
  63. angle_op = a_c_trans.Angle()
  64. dataset = dataset.map(operations=angle_op, input_columns=["col1"])
  65. num_itr = 0
  66. with pytest.raises(RuntimeError, match="input tensor type should be int, float or double"):
  67. for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
  68. num_itr += 1
  69. if __name__ == "__main__":
  70. test_func_angle_001()
  71. test_func_angle_002()
  72. test_func_angle_003()