You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_math_ops.py 2.2 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. import pytest
  2. import numpy as onp
  3. import mindspore.context as context
  4. import mindspore.numpy as mnp
  5. def rand_int(*shape):
  6. """return an random integer array with parameter shape"""
  7. res = onp.random.randint(low=1, high=5, size=shape)
  8. if isinstance(res, onp.ndarray):
  9. res = res.astype(onp.float32)
  10. return res
  11. class Cases():
  12. def __init__(self):
  13. self.device_cpu = context.get_context('device_target') == 'CPU'
  14. self.arrs = [
  15. rand_int(2),
  16. rand_int(2, 3),
  17. rand_int(2, 3, 4),
  18. rand_int(2, 3, 4, 5),
  19. ]
  20. # scalars expanded across the 0th dimension
  21. self.scalars = [
  22. rand_int(),
  23. rand_int(1),
  24. rand_int(1, 1),
  25. rand_int(1, 1, 1),
  26. ]
  27. # arrays with last dimension aligned
  28. self.aligned_arrs = [
  29. rand_int(2, 3),
  30. rand_int(1, 4, 3),
  31. rand_int(5, 1, 2, 3),
  32. rand_int(4, 2, 1, 1, 3),
  33. ]
  34. test_case = Cases()
  35. def mnp_inner(a, b):
  36. return mnp.inner(a, b)
  37. def onp_inner(a, b):
  38. return onp.inner(a, b)
  39. def test_inner():
  40. for arr1 in test_case.aligned_arrs:
  41. for arr2 in test_case.aligned_arrs:
  42. match_res(mnp_inner, onp_inner, arr1, arr2)
  43. for scalar1 in test_case.scalars:
  44. for scalar2 in test_case.scalars:
  45. match_res(mnp_inner, onp_inner,
  46. scalar1, scalar2)
  47. # check if the output from mnp function and onp function applied on the arrays are matched
  48. def match_res(mnp_fn, onp_fn, arr1, arr2):
  49. actual = mnp_fn(mnp.asarray(arr1, dtype='float32'),
  50. mnp.asarray(arr2, dtype='float32')).asnumpy()
  51. expected = onp_fn(arr1, arr2)
  52. match_array(actual, expected)
  53. def match_array(actual, expected, error=5):
  54. if error > 0:
  55. onp.testing.assert_almost_equal(actual.tolist(), expected.tolist(),
  56. decimal=error)
  57. else:
  58. onp.testing.assert_equal(actual.tolist(), expected.tolist())
  59. def test_exception_innner():
  60. with pytest.raises(ValueError):
  61. mnp.inner(mnp.asarray(test_case.arrs[0]),
  62. mnp.asarray(test_case.arrs[1]))