You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 6.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """utility functions for mindspore.numpy st tests"""
  16. import functools
  17. import numpy as onp
  18. from mindspore import Tensor
  19. import mindspore.numpy as mnp
  20. def match_array(actual, expected, error=0):
  21. if isinstance(actual, int):
  22. actual = onp.asarray(actual)
  23. if isinstance(expected, int):
  24. expected = onp.asarray(expected)
  25. if error > 0:
  26. onp.testing.assert_almost_equal(actual.tolist(), expected.tolist(),
  27. decimal=error)
  28. else:
  29. onp.testing.assert_equal(actual.tolist(), expected.tolist())
  30. def check_all_results(onp_results, mnp_results, error=0):
  31. """Check all results from numpy and mindspore.numpy"""
  32. for i, _ in enumerate(onp_results):
  33. match_array(onp_results[i], mnp_results[i].asnumpy())
  34. def check_all_unique_results(onp_results, mnp_results):
  35. """
  36. Check all results from numpy and mindspore.numpy.
  37. Args:
  38. onp_results (Union[tuple of numpy.arrays, numpy.array])
  39. mnp_results (Union[tuple of Tensors, Tensor])
  40. """
  41. for i, _ in enumerate(onp_results):
  42. if isinstance(onp_results[i], tuple):
  43. for j in range(len(onp_results[i])):
  44. match_array(onp_results[i][j],
  45. mnp_results[i][j].asnumpy(), error=7)
  46. else:
  47. match_array(onp_results[i], mnp_results[i].asnumpy(), error=7)
  48. def run_non_kw_test(mnp_fn, onp_fn, test_case):
  49. """Run tests on functions with non keyword arguments"""
  50. for i in range(len(test_case.arrs)):
  51. arrs = test_case.arrs[:i]
  52. match_res(mnp_fn, onp_fn, *arrs)
  53. for i in range(len(test_case.scalars)):
  54. arrs = test_case.scalars[:i]
  55. match_res(mnp_fn, onp_fn, *arrs)
  56. for i in range(len(test_case.expanded_arrs)):
  57. arrs = test_case.expanded_arrs[:i]
  58. match_res(mnp_fn, onp_fn, *arrs)
  59. for i in range(len(test_case.nested_arrs)):
  60. arrs = test_case.nested_arrs[:i]
  61. match_res(mnp_fn, onp_fn, *arrs)
  62. def rand_int(*shape):
  63. """return an random integer array with parameter shape"""
  64. res = onp.random.randint(low=1, high=5, size=shape)
  65. if isinstance(res, onp.ndarray):
  66. return res.astype(onp.float32)
  67. return float(res)
  68. # return an random boolean array
  69. def rand_bool(*shape):
  70. return onp.random.rand(*shape) > 0.5
  71. def match_res(mnp_fn, onp_fn, *arrs, **kwargs):
  72. """Checks results from applying mnp_fn and onp_fn on arrs respectively"""
  73. dtype = kwargs.get('dtype', mnp.float32)
  74. kwargs.pop('dtype', None)
  75. mnp_arrs = map(functools.partial(Tensor, dtype=dtype), arrs)
  76. error = kwargs.get('error', 0)
  77. kwargs.pop('error', None)
  78. mnp_res = mnp_fn(*mnp_arrs, **kwargs)
  79. onp_res = onp_fn(*arrs, **kwargs)
  80. match_all_arrays(mnp_res, onp_res, error=error)
  81. def match_all_arrays(mnp_res, onp_res, error=0):
  82. if isinstance(mnp_res, (tuple, list)):
  83. assert len(mnp_res) == len(onp_res)
  84. for actual, expected in zip(mnp_res, onp_res):
  85. match_array(actual.asnumpy(), expected, error)
  86. else:
  87. match_array(mnp_res.asnumpy(), onp_res, error)
  88. def match_meta(actual, expected):
  89. # float64 and int64 are not supported, and the default type for
  90. # float and int are float32 and int32, respectively
  91. if expected.dtype == onp.float64:
  92. expected = expected.astype(onp.float32)
  93. elif expected.dtype == onp.int64:
  94. expected = expected.astype(onp.int32)
  95. assert actual.shape == expected.shape
  96. assert actual.dtype == expected.dtype
  97. def run_binop_test(mnp_fn, onp_fn, test_case, error=0):
  98. for arr in test_case.arrs:
  99. match_res(mnp_fn, onp_fn, arr, arr, error=error)
  100. for scalar in test_case.scalars:
  101. match_res(mnp_fn, onp_fn, arr, scalar, error=error)
  102. match_res(mnp_fn, onp_fn, scalar, arr, error=error)
  103. for scalar1 in test_case.scalars:
  104. for scalar2 in test_case.scalars:
  105. match_res(mnp_fn, onp_fn, scalar1, scalar2, error=error)
  106. for expanded_arr1 in test_case.expanded_arrs:
  107. for expanded_arr2 in test_case.expanded_arrs:
  108. match_res(mnp_fn, onp_fn, expanded_arr1, expanded_arr2, error=error)
  109. for broadcastable1 in test_case.broadcastables:
  110. for broadcastable2 in test_case.broadcastables:
  111. match_res(mnp_fn, onp_fn, broadcastable1, broadcastable2, error=error)
  112. def run_unary_test(mnp_fn, onp_fn, test_case, error=0):
  113. for arr in test_case.arrs:
  114. match_res(mnp_fn, onp_fn, arr, error=error)
  115. for arr in test_case.scalars:
  116. match_res(mnp_fn, onp_fn, arr, error=error)
  117. for arr in test_case.expanded_arrs:
  118. match_res(mnp_fn, onp_fn, arr, error=error)
  119. def run_multi_test(mnp_fn, onp_fn, arrs, error=0):
  120. mnp_arrs = map(Tensor, arrs)
  121. for actual, expected in zip(mnp_fn(*mnp_arrs), onp_fn(*arrs)):
  122. match_all_arrays(actual, expected, error)
  123. def run_single_test(mnp_fn, onp_fn, arr, error=0):
  124. mnp_arr = Tensor(arr)
  125. for actual, expected in zip(mnp_fn(mnp_arr), onp_fn(arr)):
  126. if isinstance(expected, tuple):
  127. for actual_arr, expected_arr in zip(actual, expected):
  128. match_array(actual_arr.asnumpy(), expected_arr, error)
  129. match_array(actual.asnumpy(), expected, error)
  130. def run_logical_test(mnp_fn, onp_fn, test_case):
  131. for x1 in test_case.boolean_arrs:
  132. for x2 in test_case.boolean_arrs:
  133. match_res(mnp_fn, onp_fn, x1, x2, dtype=mnp.bool_)
  134. def to_tensor(obj, dtype=None):
  135. if dtype is None:
  136. res = Tensor(obj)
  137. if res.dtype == mnp.float64:
  138. res = res.astype(mnp.float32)
  139. if res.dtype == mnp.int64:
  140. res = res.astype(mnp.int32)
  141. else:
  142. res = Tensor(obj, dtype)
  143. return res