You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_math_ops.py 2.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """unit tests for numpy math operations"""
  16. import pytest
  17. import numpy as onp
  18. import mindspore.numpy as mnp
  19. def rand_int(*shape):
  20. """return an random integer array with parameter shape"""
  21. res = onp.random.randint(low=1, high=5, size=shape)
  22. if isinstance(res, onp.ndarray):
  23. res = res.astype(onp.float32)
  24. return res
  25. class Cases():
  26. def __init__(self):
  27. self.arrs = [
  28. rand_int(2),
  29. rand_int(2, 3),
  30. rand_int(2, 3, 4),
  31. rand_int(2, 3, 4, 5),
  32. ]
  33. # scalars expanded across the 0th dimension
  34. self.scalars = [
  35. rand_int(),
  36. rand_int(1),
  37. rand_int(1, 1),
  38. rand_int(1, 1, 1),
  39. ]
  40. # arrays with last dimension aligned
  41. self.aligned_arrs = [
  42. rand_int(2, 3),
  43. rand_int(1, 4, 3),
  44. rand_int(5, 1, 2, 3),
  45. rand_int(4, 2, 1, 1, 3),
  46. ]
  47. test_case = Cases()
  48. def mnp_inner(a, b):
  49. return mnp.inner(a, b)
  50. def onp_inner(a, b):
  51. return onp.inner(a, b)
  52. def test_inner():
  53. for arr1 in test_case.aligned_arrs:
  54. for arr2 in test_case.aligned_arrs:
  55. match_res(mnp_inner, onp_inner, arr1, arr2)
  56. for scalar1 in test_case.scalars:
  57. for scalar2 in test_case.scalars:
  58. match_res(mnp_inner, onp_inner,
  59. scalar1, scalar2)
  60. # check if the output from mnp function and onp function applied on the arrays are matched
  61. def match_res(mnp_fn, onp_fn, arr1, arr2):
  62. actual = mnp_fn(mnp.asarray(arr1, dtype='float32'),
  63. mnp.asarray(arr2, dtype='float32')).asnumpy()
  64. expected = onp_fn(arr1, arr2)
  65. match_array(actual, expected)
  66. def match_array(actual, expected, error=5):
  67. if error > 0:
  68. onp.testing.assert_almost_equal(actual.tolist(), expected.tolist(),
  69. decimal=error)
  70. else:
  71. onp.testing.assert_equal(actual.tolist(), expected.tolist())
  72. def test_exception_innner():
  73. with pytest.raises(ValueError):
  74. mnp.inner(mnp.asarray(test_case.arrs[0]),
  75. mnp.asarray(test_case.arrs[1]))