You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_scatter_update_op.py 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import pytest
  17. import mindspore.context as context
  18. import mindspore.nn as nn
  19. from mindspore import Tensor, Parameter
  20. from mindspore.ops import operations as P
  21. from mindspore.ops.operations import _inner_ops as inner
  22. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  23. # all cases tested against dchip
  24. class TestScatterUpdateNet(nn.Cell):
  25. def __init__(self, inputx, indices, updates):
  26. super(TestScatterUpdateNet, self).__init__()
  27. self.scatter_update = P.ScatterUpdate()
  28. self.inputx = Parameter(inputx, name="inputx")
  29. self.indices = Parameter(indices, name="indices")
  30. self.updates = Parameter(updates, name="updates")
  31. def construct(self):
  32. out = self.scatter_update(self.inputx, self.indices, self.updates)
  33. return out
  34. def scatter_update_net(inputx, indices, updates):
  35. net = TestScatterUpdateNet(inputx, indices, updates)
  36. return net()
  37. class TestScatterUpdateDynamicNet(nn.Cell):
  38. def __init__(self, inputx, indices, updates):
  39. super(TestScatterUpdateDynamicNet, self).__init__()
  40. self.scatter_update = P.ScatterUpdate()
  41. self.test_dynamic = inner.GpuConvertToDynamicShape()
  42. self.inputx = Parameter(inputx, name="inputx")
  43. self.indices = Parameter(indices, name="indices")
  44. self.updates = Parameter(updates, name="updates")
  45. def construct(self):
  46. indices = self.test_dynamic(self.indices)
  47. updates = self.test_dynamic(self.updates)
  48. out = self.scatter_update(self.inputx, indices, updates)
  49. return out
  50. def scatter_update_d_net(inputx, indices, updates):
  51. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  52. net = TestScatterUpdateDynamicNet(inputx, indices, updates)
  53. return net()
  54. class TestScatterUpdateDynamicNet2(nn.Cell):
  55. def __init__(self, inputx):
  56. super(TestScatterUpdateDynamicNet2, self).__init__()
  57. self.scatter_update = P.ScatterUpdate()
  58. self.test_dynamic = inner.GpuConvertToDynamicShape()
  59. self.inputx = Parameter(inputx, name="inputx")
  60. def construct(self, indices, updates):
  61. indices = self.test_dynamic(indices)
  62. updates = self.test_dynamic(updates)
  63. out = self.scatter_update(self.inputx, indices, updates)
  64. return out
  65. def scatter_update_d2_net(inputx, indices_1, updates_1,
  66. indices_2, updates_2):
  67. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  68. net = TestScatterUpdateDynamicNet2(inputx)
  69. out1 = net(indices_1, updates_1)
  70. out2 = net(indices_2, updates_2)
  71. return (out1, out2)
  72. @pytest.mark.level0
  73. @pytest.mark.platform_x86_gpu_training
  74. @pytest.mark.env_onecard
  75. def test_scatter_update_small_float32():
  76. inputx = Tensor(np.zeros((2, 3)).astype(np.float32))
  77. indices = Tensor(np.array([0, 1]).astype(np.int32))
  78. updates = Tensor(np.arange(6).reshape((2, 3)).astype(np.float32))
  79. output = scatter_update_net(inputx, indices, updates)
  80. expected = np.array([[0., 1., 2.],
  81. [3., 4., 5.]])
  82. np.testing.assert_array_almost_equal(output.asnumpy(), expected)
  83. @pytest.mark.level0
  84. @pytest.mark.platform_x86_gpu_training
  85. @pytest.mark.env_onecard
  86. def test_scatter_update_input_updated():
  87. inputx = Tensor(np.zeros((2, 3)).astype(np.float32))
  88. indices = Tensor(np.array([0, 1]).astype(np.int32))
  89. updates = Tensor(np.arange(6).reshape((2, 3)).astype(np.float32))
  90. net = TestScatterUpdateNet(inputx, indices, updates)
  91. net()
  92. expected = np.array([[0., 1., 2.],
  93. [3., 4., 5.]])
  94. np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected)
  95. @pytest.mark.level0
  96. @pytest.mark.platform_x86_gpu_training
  97. @pytest.mark.env_onecard
  98. def test_scatter_update_input_less_than_1_float32():
  99. inputx = Tensor(np.array([[0.214141, 0.415151, 0.51516],
  100. [0.876542, 0.451611, 0.55112],
  101. [0.111244, 0.633333, 0.34444]]).astype(np.float32))
  102. indices = Tensor(np.array([1, 0, 2]).astype(np.int32))
  103. updates = Tensor(np.arange(34, 43).reshape((3, 3)).astype(np.float32))
  104. output = scatter_update_net(inputx, indices, updates)
  105. expected = np.array([[37., 38., 39.],
  106. [34., 35., 36.],
  107. [40., 41., 42.]], dtype=np.float32)
  108. np.testing.assert_array_almost_equal(output.asnumpy(), expected)
  109. @pytest.mark.level0
  110. @pytest.mark.platform_x86_gpu_training
  111. @pytest.mark.env_onecard
  112. def test_scatter_update_float16():
  113. inputx = Tensor(np.zeros((2, 3)).astype(np.float16))
  114. indices = Tensor(np.array([0, 1]).astype(np.int32))
  115. updates = Tensor(np.arange(6).reshape((2, 3)).astype(np.float16))
  116. output = scatter_update_net(inputx, indices, updates)
  117. expected = np.array([[0., 1., 2.],
  118. [3., 4., 5.]]).astype(np.float16)
  119. np.testing.assert_array_almost_equal(output.asnumpy(), expected)
  120. @pytest.mark.level0
  121. @pytest.mark.platform_x86_gpu_training
  122. @pytest.mark.env_onecard
  123. def test_scatter_update_int32():
  124. inputx = Tensor(np.zeros((2, 3)).astype(np.int32))
  125. indices = Tensor(np.array([0, 1]).astype(np.int32))
  126. updates = Tensor(np.arange(6).reshape((2, 3)).astype(np.int32))
  127. output = scatter_update_net(inputx, indices, updates)
  128. expected = np.array([[0., 1., 2.],
  129. [3., 4., 5.]]).astype(np.int32)
  130. np.testing.assert_array_almost_equal(output.asnumpy(), expected)
  131. @pytest.mark.level0
  132. @pytest.mark.platform_x86_gpu_training
  133. @pytest.mark.env_onecard
  134. def test_scatter_update_large_float16():
  135. inputx = Tensor(np.zeros((4, 3)).astype(np.float16))
  136. indices = Tensor(np.array([[2, 1], [0, 3]]).astype(np.int32))
  137. updates = Tensor(np.arange(63, 75).reshape((2, 2, 3)).astype(np.float16))
  138. output = scatter_update_net(inputx, indices, updates)
  139. expected = np.array([[69., 70., 71.],
  140. [66., 67., 68.],
  141. [63., 64., 65.],
  142. [72., 73., 74.]]).astype(np.float16)
  143. np.testing.assert_array_almost_equal(output.asnumpy(), expected)
  144. @pytest.mark.level0
  145. @pytest.mark.platform_x86_gpu_training
  146. @pytest.mark.env_onecard
  147. def test_scatter_update_disordered_float16():
  148. inputx = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(np.float16)))
  149. indices = Tensor(np.array([1, 2]).astype(np.int32))
  150. updates = Tensor(np.arange(63, 71).reshape((2, 4)).astype(np.float16))
  151. output = scatter_update_net(inputx, indices, updates)
  152. expected = np.array([[45., 44., 43., 42.],
  153. [63., 64., 65., 66.],
  154. [67., 68., 69., 70.]]).astype(np.float16)
  155. np.testing.assert_array_almost_equal(output.asnumpy(), expected)
  156. @pytest.mark.level0
  157. @pytest.mark.platform_x86_gpu_training
  158. @pytest.mark.env_onecard
  159. def test_scatter_update_disordered_int32():
  160. inputx = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int32)))
  161. indices = Tensor(np.array([1, 2]).astype(np.int32))
  162. updates = Tensor(np.arange(63, 71).reshape((2, 4)).astype(np.int32))
  163. output = scatter_update_net(inputx, indices, updates)
  164. expected = np.array([[45., 44., 43., 42.],
  165. [63., 64., 65., 66.],
  166. [67., 68., 69., 70.]]).astype(np.int32)
  167. np.testing.assert_array_almost_equal(output.asnumpy(), expected)
  168. @pytest.mark.level0
  169. @pytest.mark.platform_x86_gpu_training
  170. @pytest.mark.env_onecard
  171. def test_scatter_update_large_shape_float16():
  172. inputx = Tensor(np.arange(96).reshape((4, 2, 3, 4)).astype(np.float16))
  173. indices = Tensor(np.array([1, 0]).astype(np.int32))
  174. updates = Tensor(np.flip(np.arange(48).reshape((2, 2, 3, 4)).astype(np.float16)))
  175. output = scatter_update_net(inputx, indices, updates)
  176. expected = np.array([[[[23., 22., 21., 20.],
  177. [19., 18., 17., 16.],
  178. [15., 14., 13., 12.]],
  179. [[11., 10., 9., 8.],
  180. [7., 6., 5., 4.],
  181. [3., 2., 1., 0.]]],
  182. [[[47., 46., 45., 44.],
  183. [43., 42., 41., 40.],
  184. [39., 38., 37., 36.]],
  185. [[35., 34., 33., 32.],
  186. [31., 30., 29., 28.],
  187. [27., 26., 25., 24.]]],
  188. [[[48., 49., 50., 51.],
  189. [52., 53., 54., 55.],
  190. [56., 57., 58., 59.]],
  191. [[60., 61., 62., 63.],
  192. [64., 65., 66., 67.],
  193. [68., 69., 70., 71.]]],
  194. [[[72., 73., 74., 75.],
  195. [76., 77., 78., 79.],
  196. [80., 81., 82., 83.]],
  197. [[84., 85., 86., 87.],
  198. [88., 89., 90., 91.],
  199. [92., 93., 94., 95.]]]]).astype(np.float16)
  200. np.testing.assert_array_almost_equal(output.asnumpy(), expected)
  201. @pytest.mark.level0
  202. @pytest.mark.platform_x86_gpu_training
  203. @pytest.mark.env_onecard
  204. def test_scatter_update_disordered_int8():
  205. inputx = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int8)))
  206. indices = Tensor(np.array([1, 2]).astype(np.int32))
  207. updates = Tensor(np.arange(63, 71).reshape((2, 4)).astype(np.int8))
  208. output = scatter_update_net(inputx, indices, updates)
  209. expected = np.array([[45., 44., 43., 42.],
  210. [63., 64., 65., 66.],
  211. [67., 68., 69., 70.]]).astype(np.int8)
  212. np.testing.assert_array_almost_equal(output.asnumpy(), expected)
  213. @pytest.mark.level0
  214. @pytest.mark.platform_x86_gpu_training
  215. @pytest.mark.env_onecard
  216. def test_scatter_update_large_shape_int8():
  217. inputx = Tensor(np.arange(96).reshape((4, 2, 3, 4)).astype(np.int8))
  218. indices = Tensor(np.array([1, 0]).astype(np.int32))
  219. updates = Tensor(np.flip(np.arange(48).reshape((2, 2, 3, 4)).astype(np.int8)))
  220. output = scatter_update_net(inputx, indices, updates)
  221. expected = np.array([[[[23., 22., 21., 20.],
  222. [19., 18., 17., 16.],
  223. [15., 14., 13., 12.]],
  224. [[11., 10., 9., 8.],
  225. [7., 6., 5., 4.],
  226. [3., 2., 1., 0.]]],
  227. [[[47., 46., 45., 44.],
  228. [43., 42., 41., 40.],
  229. [39., 38., 37., 36.]],
  230. [[35., 34., 33., 32.],
  231. [31., 30., 29., 28.],
  232. [27., 26., 25., 24.]]],
  233. [[[48., 49., 50., 51.],
  234. [52., 53., 54., 55.],
  235. [56., 57., 58., 59.]],
  236. [[60., 61., 62., 63.],
  237. [64., 65., 66., 67.],
  238. [68., 69., 70., 71.]]],
  239. [[[72., 73., 74., 75.],
  240. [76., 77., 78., 79.],
  241. [80., 81., 82., 83.]],
  242. [[84., 85., 86., 87.],
  243. [88., 89., 90., 91.],
  244. [92., 93., 94., 95.]]]]).astype(np.int8)
  245. np.testing.assert_array_almost_equal(output.asnumpy(), expected)
  246. @pytest.mark.level0
  247. @pytest.mark.platform_x86_gpu_training
  248. @pytest.mark.env_onecard
  249. def test_scatter_update_large_uint8():
  250. inputx = Tensor(np.zeros((4, 3)).astype(np.uint8))
  251. indices = Tensor(np.array([[2, 1], [0, 3]]).astype(np.int32))
  252. updates = Tensor(np.arange(63, 75).reshape((2, 2, 3)).astype(np.uint8))
  253. output = scatter_update_net(inputx, indices, updates)
  254. expected = np.array([[69., 70., 71.],
  255. [66., 67., 68.],
  256. [63., 64., 65.],
  257. [72., 73., 74.]]).astype(np.uint8)
  258. np.testing.assert_array_almost_equal(output.asnumpy(), expected)
  259. @pytest.mark.level0
  260. @pytest.mark.platform_x86_gpu_training
  261. @pytest.mark.env_onecard
  262. def test_scatter_update_disordered_uint8():
  263. inputx = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(np.uint8)))
  264. indices = Tensor(np.array([1, 2]).astype(np.int32))
  265. updates = Tensor(np.arange(63, 71).reshape((2, 4)).astype(np.uint8))
  266. output = scatter_update_net(inputx, indices, updates)
  267. expected = np.array([[45., 44., 43., 42.],
  268. [63., 64., 65., 66.],
  269. [67., 68., 69., 70.]]).astype(np.uint8)
  270. np.testing.assert_array_almost_equal(output.asnumpy(), expected)
  271. @pytest.mark.level0
  272. @pytest.mark.platform_x86_gpu_training
  273. @pytest.mark.env_onecard
  274. def test_scatter_update_large_shape_dynamic_int8():
  275. inputx = Tensor(np.arange(96).reshape((4, 2, 3, 4)).astype(np.int8))
  276. indices = Tensor(np.array([1, 0]).astype(np.int32))
  277. updates = Tensor(np.flip(np.arange(48).reshape((2, 2, 3, 4)).astype(np.int8)))
  278. output = scatter_update_d_net(inputx, indices, updates)
  279. expected = np.array([[[[23., 22., 21., 20.],
  280. [19., 18., 17., 16.],
  281. [15., 14., 13., 12.]],
  282. [[11., 10., 9., 8.],
  283. [7., 6., 5., 4.],
  284. [3., 2., 1., 0.]]],
  285. [[[47., 46., 45., 44.],
  286. [43., 42., 41., 40.],
  287. [39., 38., 37., 36.]],
  288. [[35., 34., 33., 32.],
  289. [31., 30., 29., 28.],
  290. [27., 26., 25., 24.]]],
  291. [[[48., 49., 50., 51.],
  292. [52., 53., 54., 55.],
  293. [56., 57., 58., 59.]],
  294. [[60., 61., 62., 63.],
  295. [64., 65., 66., 67.],
  296. [68., 69., 70., 71.]]],
  297. [[[72., 73., 74., 75.],
  298. [76., 77., 78., 79.],
  299. [80., 81., 82., 83.]],
  300. [[84., 85., 86., 87.],
  301. [88., 89., 90., 91.],
  302. [92., 93., 94., 95.]]]]).astype(np.int8)
  303. np.testing.assert_array_almost_equal(output.asnumpy(), expected)
  304. @pytest.mark.level0
  305. @pytest.mark.platform_x86_gpu_training
  306. @pytest.mark.env_onecard
  307. def test_scatter_update_disordered_dynamic_int32():
  308. inputx = Tensor(np.flip(np.arange(34, 46).reshape(3, 4).astype(np.int32)))
  309. indices = Tensor(np.array([1, 2]).astype(np.int32))
  310. updates = Tensor(np.arange(63, 71).reshape((2, 4)).astype(np.int32))
  311. output = scatter_update_d_net(inputx, indices, updates)
  312. expected = np.array([[45., 44., 43., 42.],
  313. [63., 64., 65., 66.],
  314. [67., 68., 69., 70.]]).astype(np.int32)
  315. np.testing.assert_array_almost_equal(output.asnumpy(), expected)
  316. @pytest.mark.level0
  317. @pytest.mark.platform_x86_gpu_training
  318. @pytest.mark.env_onecard
  319. def test_scatter_update_two_inputs():
  320. inputx = Tensor(np.zeros((2, 3)).astype(np.float32))
  321. indices_1 = Tensor(np.array([0, 1]).astype(np.int32))
  322. updates_1 = Tensor(np.arange(6).reshape((2, 3)).astype(np.float32))
  323. indices_2 = Tensor(np.array([1]).astype(np.int32))
  324. updates_2 = Tensor(np.arange(34, 37).reshape((1, 3)).astype(np.float32))
  325. output_1, output_2 = scatter_update_d2_net(inputx, indices_1, updates_1,
  326. indices_2, updates_2)
  327. expected_1 = np.array([[0., 1., 2.],
  328. [3., 4., 5.]], dtype=np.float32)
  329. expected_2 = np.array([[0., 1., 2.],
  330. [34., 35., 36.]], dtype=np.float32)
  331. np.testing.assert_array_almost_equal(output_1.asnumpy(), expected_1)
  332. np.testing.assert_array_almost_equal(output_2.asnumpy(), expected_2)