You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_gathernd_op.py 7.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import pytest
  17. from mindspore import Tensor
  18. from mindspore.ops import operations as P
  19. import mindspore.nn as nn
  20. import mindspore.context as context
  21. class GatherNdNet(nn.Cell):
  22. def __init__(self):
  23. super(GatherNdNet, self).__init__()
  24. self.gathernd = P.GatherNd()
  25. def construct(self, x, indices):
  26. return self.gathernd(x, indices)
  27. def gathernd0(nptype):
  28. x = Tensor(np.arange(3 * 2, dtype=nptype).reshape(3, 2))
  29. indices = Tensor(np.array([[1, 1], [0, 1]]).astype(np.int32))
  30. expect = np.array([3, 1]).astype(nptype)
  31. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  32. gathernd = GatherNdNet()
  33. output = gathernd(x, indices)
  34. assert np.array_equal(output.asnumpy(), expect)
  35. @pytest.mark.level0
  36. @pytest.mark.platform_x86_gpu_training
  37. @pytest.mark.env_onecard
  38. def test_gathernd0_float64():
  39. gathernd0(np.float64)
  40. @pytest.mark.level0
  41. @pytest.mark.platform_x86_gpu_training
  42. @pytest.mark.env_onecard
  43. def test_gathernd0_float32():
  44. gathernd0(np.float32)
  45. @pytest.mark.level0
  46. @pytest.mark.platform_x86_gpu_training
  47. @pytest.mark.env_onecard
  48. def test_gathernd0_float16():
  49. gathernd0(np.float16)
  50. @pytest.mark.level0
  51. @pytest.mark.platform_x86_gpu_training
  52. @pytest.mark.env_onecard
  53. def test_gathernd0_int32():
  54. gathernd0(np.int32)
  55. @pytest.mark.level0
  56. @pytest.mark.platform_x86_gpu_training
  57. @pytest.mark.env_onecard
  58. def test_gathernd0_int16():
  59. gathernd0(np.int16)
  60. @pytest.mark.level0
  61. @pytest.mark.platform_x86_gpu_training
  62. @pytest.mark.env_onecard
  63. def test_gathernd0_uint8():
  64. gathernd0(np.uint8)
  65. def gathernd1(nptype):
  66. x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=nptype).reshape(2, 3, 4, 5))
  67. indices = Tensor(np.array([[[[[l, k, j, i] for i in [1, 3, 4]] for j in range(4)]
  68. for k in range(3)] for l in range(2)], dtype='i4'))
  69. expect = np.array([[[[1., 3., 4.],
  70. [6., 8., 9.],
  71. [11., 13., 14.],
  72. [16., 18., 19.]],
  73. [[21., 23., 24.],
  74. [26., 28., 29.],
  75. [31., 33., 34.],
  76. [36., 38., 39.]],
  77. [[41., 43., 44.],
  78. [46., 48., 49.],
  79. [51., 53., 54.],
  80. [56., 58., 59.]]],
  81. [[[61., 63., 64.],
  82. [66., 68., 69.],
  83. [71., 73., 74.],
  84. [76., 78., 79.]],
  85. [[81., 83., 84.],
  86. [86., 88., 89.],
  87. [91., 93., 94.],
  88. [96., 98., 99.]],
  89. [[101., 103., 104.],
  90. [106., 108., 109.],
  91. [111., 113., 114.],
  92. [116., 118., 119.]]]]).astype(nptype)
  93. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  94. gather = GatherNdNet()
  95. output = gather(x, indices)
  96. assert np.array_equal(output.asnumpy(), expect)
  97. @pytest.mark.level0
  98. @pytest.mark.platform_x86_gpu_training
  99. @pytest.mark.env_onecard
  100. def test_gathernd1_float64():
  101. gathernd1(np.float64)
  102. @pytest.mark.level0
  103. @pytest.mark.platform_x86_gpu_training
  104. @pytest.mark.env_onecard
  105. def test_gathernd1_float32():
  106. gathernd1(np.float32)
  107. @pytest.mark.level0
  108. @pytest.mark.platform_x86_gpu_training
  109. @pytest.mark.env_onecard
  110. def test_gathernd1_float16():
  111. gathernd1(np.float16)
  112. @pytest.mark.level0
  113. @pytest.mark.platform_x86_gpu_training
  114. @pytest.mark.env_onecard
  115. def test_gathernd1_int32():
  116. gathernd1(np.int32)
  117. @pytest.mark.level0
  118. @pytest.mark.platform_x86_gpu_training
  119. @pytest.mark.env_onecard
  120. def test_gathernd1_int16():
  121. gathernd1(np.int16)
  122. @pytest.mark.level0
  123. @pytest.mark.platform_x86_gpu_training
  124. @pytest.mark.env_onecard
  125. def test_gathernd1_uint8():
  126. gathernd1(np.uint8)
  127. def gathernd2(nptype):
  128. x = Tensor(np.array([[4., 5., 4., 1., 5.],
  129. [4., 9., 5., 6., 4.],
  130. [9., 8., 4., 3., 6.],
  131. [0., 4., 2., 2., 8.],
  132. [1., 8., 6., 2., 8.],
  133. [8., 1., 9., 7., 3.],
  134. [7., 9., 2., 5., 7.],
  135. [9., 8., 6., 8., 5.],
  136. [3., 7., 2., 7., 4.],
  137. [4., 2., 8., 2., 9.]]).astype(np.float16))
  138. indices = Tensor(np.array([[4000], [1], [300000]]).astype(np.int32))
  139. expect = np.array([[0., 0., 0., 0., 0.],
  140. [4., 9., 5., 6., 4.],
  141. [0., 0., 0., 0., 0.]])
  142. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  143. gathernd = GatherNdNet()
  144. output = gathernd(x, indices)
  145. assert np.array_equal(output.asnumpy(), expect)
  146. @pytest.mark.level0
  147. @pytest.mark.platform_x86_gpu_training
  148. @pytest.mark.env_onecard
  149. def test_gathernd2_float64():
  150. gathernd2(np.float64)
  151. @pytest.mark.level0
  152. @pytest.mark.platform_x86_gpu_training
  153. @pytest.mark.env_onecard
  154. def test_gathernd2_float32():
  155. gathernd2(np.float32)
  156. @pytest.mark.level0
  157. @pytest.mark.platform_x86_gpu_training
  158. @pytest.mark.env_onecard
  159. def test_gathernd2_float16():
  160. gathernd2(np.float16)
  161. @pytest.mark.level0
  162. @pytest.mark.platform_x86_gpu_training
  163. @pytest.mark.env_onecard
  164. def test_gathernd2_int32():
  165. gathernd2(np.int32)
  166. @pytest.mark.level0
  167. @pytest.mark.platform_x86_gpu_training
  168. @pytest.mark.env_onecard
  169. def test_gathernd2_int16():
  170. gathernd2(np.int16)
  171. @pytest.mark.level0
  172. @pytest.mark.platform_x86_gpu_training
  173. @pytest.mark.env_onecard
  174. def test_gathernd2_uint8():
  175. gathernd2(np.uint8)
  176. @pytest.mark.level0
  177. @pytest.mark.platform_x86_gpu_training
  178. @pytest.mark.env_onecard
  179. def test_gathernd_bool():
  180. x = Tensor(np.array([[True, False], [False, False]]).astype(np.bool))
  181. indices = Tensor(np.array([[0, 0], [0, 1], [1, 0], [1, 1]]).astype(np.int32))
  182. expect = np.array([True, False, False, False]).astype(np.bool)
  183. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  184. gathernd = GatherNdNet()
  185. output = gathernd(x, indices)
  186. assert np.array_equal(output.asnumpy(), expect)
  187. @pytest.mark.level0
  188. @pytest.mark.platform_x86_gpu_training
  189. @pytest.mark.env_onecard
  190. def test_gathernd_indices_int64():
  191. x = Tensor(np.array([[True, False], [False, False]]).astype(np.bool))
  192. indices = Tensor(np.array([[0, 0], [0, 1], [1, 0], [1, 1]]).astype(np.int64))
  193. expect = np.array([True, False, False, False]).astype(np.bool)
  194. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  195. gathernd = GatherNdNet()
  196. output = gathernd(x, indices)
  197. assert np.array_equal(output.asnumpy(), expect)