You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_scatter_arithmetic_op.py 26 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import pytest
  17. import mindspore.context as context
  18. import mindspore.nn as nn
  19. from mindspore import Tensor, Parameter
  20. from mindspore.ops import operations as P
  # Every test in this module runs in MindSpore graph mode on the CPU backend.
  context.set_context(device_target="CPU", mode=context.GRAPH_MODE)
  class TestScatterAddNet(nn.Cell):
      """Cell that applies ``P.ScatterAdd`` to operands held as Parameters.

      Args:
          lock: forwarded to ``P.ScatterAdd(use_locking=lock)``.
          inputx: initial value of the ``inputx`` Parameter (the tensor being
              scattered into; the *_input_updated tests read it after a call).
          indices: value of the ``indices`` Parameter.
          updates: value of the ``updates`` Parameter.
      """

      # The "Test" prefix matches pytest's default class-collection pattern, but
      # this is a helper Cell with an __init__, not a test suite; opting out
      # avoids a PytestCollectionWarning on every run.
      __test__ = False

      def __init__(self, lock, inputx, indices, updates):
          super(TestScatterAddNet, self).__init__()
          self.scatter_add = P.ScatterAdd(use_locking=lock)
          self.inputx = Parameter(inputx, name="inputx")
          self.indices = Parameter(indices, name="indices")
          self.updates = Parameter(updates, name="updates")

      def construct(self):
          # Return the operator's output tensor.
          out = self.scatter_add(self.inputx, self.indices, self.updates)
          return out
  def scatter_add_net(inputx, indices, updates):
      """Build a ScatterAdd Cell with use_locking=True, run it once, return the output."""
      return TestScatterAddNet(True, inputx, indices, updates)()
  def scatter_add_use_locking_false_net(inputx, indices, updates):
      """Build a ScatterAdd Cell with use_locking=False, run it once, return the output."""
      return TestScatterAddNet(False, inputx, indices, updates)()
  @pytest.mark.level0
  @pytest.mark.platform_x86_cpu
  @pytest.mark.env_onecard
  def test_scatter_add_small_float32():
      """ScatterAdd sums every update slice that targets the same row."""
      # Row 0 receives [0, 1, 2] and [6, 7, 8]; row 1 receives [3, 4, 5] and [9, 10, 11].
      want = np.array([[6., 8., 10.],
                       [12., 14., 16.]])
      input_x = Tensor(np.zeros((2, 3), dtype=np.float32))
      idx = Tensor(np.array([[0, 1], [0, 1]], dtype=np.int32))
      upd = Tensor(np.arange(12, dtype=np.float32).reshape(2, 2, 3))
      result = scatter_add_net(input_x, idx, upd)
      np.testing.assert_array_almost_equal(result.asnumpy(), want)
  @pytest.mark.level0
  @pytest.mark.platform_x86_cpu
  @pytest.mark.env_onecard
  def test_scatter_add_input_updated():
      """ScatterAdd leaves its accumulated result in the inputx Parameter."""
      want = np.array([[6., 8., 10.],
                       [12., 14., 16.]])
      net = TestScatterAddNet(True,
                              Tensor(np.zeros((2, 3), dtype=np.float32)),
                              Tensor(np.array([[0, 1], [0, 1]], dtype=np.int32)),
                              Tensor(np.arange(12, dtype=np.float32).reshape(2, 2, 3)))
      net()
      # Check the Parameter itself, not the returned tensor.
      np.testing.assert_array_almost_equal(net.inputx.asnumpy(), want)
  @pytest.mark.level0
  @pytest.mark.platform_x86_cpu
  @pytest.mark.env_onecard
  def test_scatter_add_large_shape_float32():
      """ScatterAdd on a rank-4 input indexed by a 2x2 grid of distinct indices."""
      input_x = Tensor(np.ones((4, 2, 3, 4), dtype=np.float32))
      idx = Tensor(np.array([[0, 2], [3, 1]], dtype=np.int32))
      upd = Tensor(np.arange(96, dtype=np.float32).reshape(2, 2, 2, 3, 4))
      result = scatter_add_net(input_x, idx, upd)
      # Each slot of input_x (all ones) receives exactly one 24-value slice of
      # upd: idx [[0, 2], [3, 1]] sends slices 0, 1, 2, 3 to slots 0, 2, 3, 1.
      want = 1. + np.stack([np.arange(0., 24.),    # slot 0 <- slice 0
                            np.arange(72., 96.),   # slot 1 <- slice 3
                            np.arange(24., 48.),   # slot 2 <- slice 1
                            np.arange(48., 72.)])  # slot 3 <- slice 2
      want = want.reshape(4, 2, 3, 4)
      np.testing.assert_array_almost_equal(result.asnumpy(), want)
  @pytest.mark.level0
  @pytest.mark.platform_x86_cpu
  @pytest.mark.env_onecard
  def test_scatter_add_small_float32_use_locking_false():
      """ScatterAdd with use_locking=False and a permuted 1-D index vector."""
      input_x = Tensor(np.zeros((2, 3), dtype=np.float32))
      idx = Tensor(np.array([1, 0], dtype=np.int32))
      upd = Tensor(np.arange(6, dtype=np.float32).reshape(2, 3))
      result = scatter_add_use_locking_false_net(input_x, idx, upd)
      # Row 0 of upd lands in row 1, and row 1 of upd lands in row 0.
      want = np.array([[3., 4., 5.],
                       [0., 1., 2.]])
      np.testing.assert_array_almost_equal(result.asnumpy(), want)
  @pytest.mark.level0
  @pytest.mark.platform_x86_cpu
  @pytest.mark.env_onecard
  def test_scatter_add_input_less_than_1_float32():
      """ScatterAdd preserves fractional input values while adding integer updates."""
      input_x = Tensor(np.array([[0.214141, 0.415151, 0.51516],
                                 [0.876542, 0.451611, 0.55112],
                                 [0.111244, 0.633333, 0.34444]], dtype=np.float32))
      # Index tensor of shape (2, 2, 3); every entry selects one row of input_x.
      idx = Tensor(np.array([[[1, 0, 2],
                              [2, 2, 0]],
                             [[1, 0, 1],
                              [2, 1, 2]]], dtype=np.int32))
      upd = Tensor(np.arange(34, 70, dtype=np.float32).reshape(2, 2, 3, 3))
      result = scatter_add_net(input_x, idx, upd)
      want = np.array([[141.21414, 144.41515, 147.51517],
                       [208.87654, 212.45161, 216.55112],
                       [257.11124, 262.63333, 267.34442]], dtype=np.float32)
      np.testing.assert_array_almost_equal(result.asnumpy(), want)
  @pytest.mark.level0
  @pytest.mark.platform_x86_cpu
  @pytest.mark.env_onecard
  def test_scatter_add_float16():
      """Duplicate-index accumulation, as in the small float32 case, in float16."""
      want = np.array([[6., 8., 10.],
                       [12., 14., 16.]])
      input_x = Tensor(np.zeros((2, 3), dtype=np.float16))
      idx = Tensor(np.array([[0, 1], [0, 1]], dtype=np.int32))
      upd = Tensor(np.arange(12).astype(np.float16).reshape(2, 2, 3))
      result = scatter_add_net(input_x, idx, upd)
      np.testing.assert_array_almost_equal(result.asnumpy(), want)
  @pytest.mark.level0
  @pytest.mark.platform_x86_cpu
  @pytest.mark.env_onecard
  def test_scatter_add_large_float16():
      """float16 ScatterAdd where both slices for a slot share the same index."""
      input_x = Tensor(np.zeros((2, 3, 4), dtype=np.float16))
      idx = Tensor(np.array([[0, 0], [1, 1]], dtype=np.int32))
      upd = Tensor(np.arange(63, 111).astype(np.float16).reshape(2, 2, 3, 4))
      result = scatter_add_net(input_x, idx, upd)
      # Slot 0 = (63..74) + (75..86); slot 1 = (87..98) + (99..110).
      want = np.array([[[138., 140., 142., 144.],
                        [146., 148., 150., 152.],
                        [154., 156., 158., 160.]],
                       [[186., 188., 190., 192.],
                        [194., 196., 198., 200.],
                        [202., 204., 206., 208.]]])
      np.testing.assert_array_almost_equal(result.asnumpy(), want)
  @pytest.mark.level0
  @pytest.mark.platform_x86_cpu
  @pytest.mark.env_onecard
  def test_scatter_add_disordered_float16():
      """float16 ScatterAdd with a reversed input and an unordered rank-3 index."""
      base = np.arange(34, 46).reshape(3, 4).astype(np.float16)
      input_x = Tensor(np.flip(base))  # all axes reversed: 45 down to 34
      idx = Tensor(np.array([[[0, 1, 2],
                              [2, 1, 0]],
                             [[0, 0, 0],
                              [2, 2, 2]]], dtype=np.int32))
      upd = Tensor(np.arange(63, 111).reshape(2, 2, 3, 4).astype(np.float16))
      result = scatter_add_net(input_x, idx, upd)
      want = np.array([[464., 468., 472., 476.],
                       [187., 188., 189., 190.],
                       [492., 496., 500., 504.]])
      np.testing.assert_array_almost_equal(result.asnumpy(), want)
  @pytest.mark.level0
  @pytest.mark.platform_x86_cpu
  @pytest.mark.env_onecard
  def test_scatter_add_large_int32():
      """int32 variant of the large float16 accumulation test."""
      input_x = Tensor(np.zeros((2, 3, 4), dtype=np.int32))
      idx = Tensor(np.array([[0, 0], [1, 1]], dtype=np.int32))
      upd = Tensor(np.arange(63, 111, dtype=np.int32).reshape(2, 2, 3, 4))
      result = scatter_add_net(input_x, idx, upd)
      want = np.array([[[138, 140, 142, 144],
                        [146, 148, 150, 152],
                        [154, 156, 158, 160]],
                       [[186, 188, 190, 192],
                        [194, 196, 198, 200],
                        [202, 204, 206, 208]]], dtype=np.int32)
      np.testing.assert_array_almost_equal(result.asnumpy(), want)
  @pytest.mark.level0
  @pytest.mark.platform_x86_cpu
  @pytest.mark.env_onecard
  def test_scatter_add_disordered_int32():
      """int32 variant of the disordered float16 accumulation test."""
      base = np.arange(34, 46).reshape(3, 4).astype(np.int32)
      input_x = Tensor(np.flip(base))  # all axes reversed: 45 down to 34
      idx = Tensor(np.array([[[0, 1, 2],
                              [2, 1, 0]],
                             [[0, 0, 0],
                              [2, 2, 2]]], dtype=np.int32))
      upd = Tensor(np.arange(63, 111).reshape(2, 2, 3, 4).astype(np.int32))
      result = scatter_add_net(input_x, idx, upd)
      want = np.array([[464, 468, 472, 476],
                       [187, 188, 189, 190],
                       [492, 496, 500, 504]], dtype=np.int32)
      np.testing.assert_array_almost_equal(result.asnumpy(), want)
  class TestScatterSubNet(nn.Cell):
      """Cell that applies ``P.ScatterSub`` to operands held as Parameters.

      Args:
          lock: forwarded to ``P.ScatterSub(use_locking=lock)``.
          inputx: initial value of the ``inputx`` Parameter (the tensor being
              scattered into; the *_input_updated test reads it after a call).
          indices: value of the ``indices`` Parameter.
          updates: value of the ``updates`` Parameter.
      """

      # The "Test" prefix matches pytest's default class-collection pattern, but
      # this is a helper Cell with an __init__, not a test suite; opting out
      # avoids a PytestCollectionWarning on every run.
      __test__ = False

      def __init__(self, lock, inputx, indices, updates):
          super(TestScatterSubNet, self).__init__()
          self.scatter_sub = P.ScatterSub(use_locking=lock)
          self.inputx = Parameter(inputx, name="inputx")
          self.indices = Parameter(indices, name="indices")
          self.updates = Parameter(updates, name="updates")

      def construct(self):
          # Return the operator's output tensor.
          out = self.scatter_sub(self.inputx, self.indices, self.updates)
          return out
  def scatter_sub_net(inputx, indices, updates):
      """Build a ScatterSub Cell with use_locking=True, run it once, return the output."""
      return TestScatterSubNet(True, inputx, indices, updates)()
  def scatter_sub_use_locking_false_net(inputx, indices, updates):
      """Build a ScatterSub Cell with use_locking=False, run it once, return the output."""
      return TestScatterSubNet(False, inputx, indices, updates)()
  @pytest.mark.level0
  @pytest.mark.platform_x86_cpu
  @pytest.mark.env_onecard
  def test_scatter_sub_input_updated():
      """ScatterSub leaves its result in the inputx Parameter."""
      # Both slices per row are subtracted from zero.
      want = -np.array([[6., 8., 10.],
                        [12., 14., 16.]])
      net = TestScatterSubNet(True,
                              Tensor(np.zeros((2, 3), dtype=np.float32)),
                              Tensor(np.array([[0, 1], [0, 1]], dtype=np.int32)),
                              Tensor(np.arange(12, dtype=np.float32).reshape(2, 2, 3)))
      net()
      np.testing.assert_array_almost_equal(net.inputx.asnumpy(), want)
  @pytest.mark.level0
  @pytest.mark.platform_x86_cpu
  @pytest.mark.env_onecard
  def test_scatter_sub_large_shape_float32():
      """ScatterSub on a rank-4 input indexed by a 2x2 grid of distinct indices."""
      input_x = Tensor(np.ones((4, 2, 3, 4), dtype=np.float32))
      idx = Tensor(np.array([[0, 2], [3, 1]], dtype=np.int32))
      upd = Tensor(np.arange(96, dtype=np.float32).reshape(2, 2, 2, 3, 4))
      result = scatter_sub_net(input_x, idx, upd)
      # Each slot of input_x (all ones) loses exactly one 24-value slice of
      # upd: idx [[0, 2], [3, 1]] sends slices 0, 1, 2, 3 to slots 0, 2, 3, 1.
      want = 1. - np.stack([np.arange(0., 24.),    # slot 0 <- slice 0
                            np.arange(72., 96.),   # slot 1 <- slice 3
                            np.arange(24., 48.),   # slot 2 <- slice 1
                            np.arange(48., 72.)])  # slot 3 <- slice 2
      want = want.reshape(4, 2, 3, 4)
      np.testing.assert_array_almost_equal(result.asnumpy(), want)
  @pytest.mark.level0
  @pytest.mark.platform_x86_cpu
  @pytest.mark.env_onecard
  def test_scatter_sub_small_float32_use_locking_false():
      """ScatterSub with use_locking=False and a permuted 1-D index vector."""
      input_x = Tensor(np.zeros((2, 3), dtype=np.float32))
      idx = Tensor(np.array([1, 0], dtype=np.int32))
      upd = Tensor(np.arange(6, dtype=np.float32).reshape(2, 3))
      result = scatter_sub_use_locking_false_net(input_x, idx, upd)
      # Row 0 of upd is subtracted from row 1, and row 1 from row 0.
      want = -np.array([[3., 4., 5.],
                        [0., 1., 2.]])
      np.testing.assert_array_almost_equal(result.asnumpy(), want)
  class TestScatterMulNet(nn.Cell):
      """Cell that applies ``P.ScatterMul`` to operands held as Parameters.

      Args:
          lock: forwarded to ``P.ScatterMul(use_locking=lock)``.
          inputx: initial value of the ``inputx`` Parameter (the tensor being
              scattered into; the *_input_updated test reads it after a call).
          indices: value of the ``indices`` Parameter.
          updates: value of the ``updates`` Parameter.
      """

      # The "Test" prefix matches pytest's default class-collection pattern, but
      # this is a helper Cell with an __init__, not a test suite; opting out
      # avoids a PytestCollectionWarning on every run.
      __test__ = False

      def __init__(self, lock, inputx, indices, updates):
          super(TestScatterMulNet, self).__init__()
          self.scatter_mul = P.ScatterMul(use_locking=lock)
          self.inputx = Parameter(inputx, name="inputx")
          self.indices = Parameter(indices, name="indices")
          self.updates = Parameter(updates, name="updates")

      def construct(self):
          # Return the operator's output tensor.
          out = self.scatter_mul(self.inputx, self.indices, self.updates)
          return out
  def scatter_mul_net(inputx, indices, updates):
      """Build a ScatterMul Cell with use_locking=True, run it once, return the output."""
      return TestScatterMulNet(True, inputx, indices, updates)()
  def scatter_mul_use_locking_false_net(inputx, indices, updates):
      """Build a ScatterMul Cell with use_locking=False, run it once, return the output."""
      return TestScatterMulNet(False, inputx, indices, updates)()
  @pytest.mark.level0
  @pytest.mark.platform_x86_cpu
  @pytest.mark.env_onecard
  def test_scatter_mul_input_updated():
      """ScatterMul leaves its product in the inputx Parameter."""
      # Row 0 = 1 * [0, 1, 2] * [6, 7, 8]; row 1 = 1 * [3, 4, 5] * [9, 10, 11].
      want = np.array([[0., 7., 16.],
                       [27., 40., 55.]])
      net = TestScatterMulNet(True,
                              Tensor(np.ones((2, 3), dtype=np.float32)),
                              Tensor(np.array([[0, 1], [0, 1]], dtype=np.int32)),
                              Tensor(np.arange(12, dtype=np.float32).reshape(2, 2, 3)))
      net()
      np.testing.assert_array_almost_equal(net.inputx.asnumpy(), want)
  @pytest.mark.level0
  @pytest.mark.platform_x86_cpu
  @pytest.mark.env_onecard
  def test_scatter_mul_output_updated_float32():
      """ScatterMul (use_locking=True) returns the accumulated products."""
      input_x = Tensor(np.ones((2, 3), dtype=np.float32))
      idx = Tensor(np.array([[0, 1], [0, 1]], dtype=np.int32))
      upd = Tensor(np.arange(12, dtype=np.float32).reshape(2, 2, 3))
      result = scatter_mul_net(input_x, idx, upd)
      want = np.array([[0., 7., 16.],
                       [27., 40., 55.]])
      np.testing.assert_array_almost_equal(result.asnumpy(), want)
  @pytest.mark.level0
  @pytest.mark.platform_x86_cpu
  @pytest.mark.env_onecard
  def test_scatter_mul_small_float32_use_locking_false():
      """ScatterMul with use_locking=False returns the same products as the locked case."""
      input_x = Tensor(np.ones((2, 3), dtype=np.float32))
      idx = Tensor(np.array([[0, 1], [0, 1]], dtype=np.int32))
      upd = Tensor(np.arange(12, dtype=np.float32).reshape(2, 2, 3))
      result = scatter_mul_use_locking_false_net(input_x, idx, upd)
      want = np.array([[0., 7., 16.],
                       [27., 40., 55.]])
      np.testing.assert_array_almost_equal(result.asnumpy(), want)
  class TestScatterDivNet(nn.Cell):
      """Cell that applies ``P.ScatterDiv`` to operands held as Parameters.

      Args:
          lock: forwarded to ``P.ScatterDiv(use_locking=lock)``.
          inputx: initial value of the ``inputx`` Parameter (the tensor being
              scattered into; the *_input_updated test reads it after a call).
          indices: value of the ``indices`` Parameter.
          updates: value of the ``updates`` Parameter.
      """

      # The "Test" prefix matches pytest's default class-collection pattern, but
      # this is a helper Cell with an __init__, not a test suite; opting out
      # avoids a PytestCollectionWarning on every run.
      __test__ = False

      def __init__(self, lock, inputx, indices, updates):
          super(TestScatterDivNet, self).__init__()
          self.scatter_div = P.ScatterDiv(use_locking=lock)
          self.inputx = Parameter(inputx, name="inputx")
          self.indices = Parameter(indices, name="indices")
          self.updates = Parameter(updates, name="updates")

      def construct(self):
          # Return the operator's output tensor.
          out = self.scatter_div(self.inputx, self.indices, self.updates)
          return out
  def scatter_div_net(inputx, indices, updates):
      """Build a ScatterDiv Cell with use_locking=True, run it once, return the output."""
      return TestScatterDivNet(True, inputx, indices, updates)()
  def scatter_div_use_locking_false_net(inputx, indices, updates):
      """Build a ScatterDiv Cell with use_locking=False, run it once, return the output."""
      return TestScatterDivNet(False, inputx, indices, updates)()
  @pytest.mark.level0
  @pytest.mark.platform_x86_cpu
  @pytest.mark.env_onecard
  def test_scatter_div_input_updated():
      """ScatterDiv leaves its quotients in the inputx Parameter.

      inputx is deliberately non-zero. With an all-zero input every quotient is
      zero, so a kernel that never divided or never wrote back would still pass.
      Each entry below is (product of its two divisors) * (wanted value), so
      every intermediate quotient is an exact float32 integer in either order.
      """
      inputx = Tensor(np.array([[14., 48., 108.],
                                [200., 330., 504.]]).astype(np.float32))
      indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32))
      # Row 0 is divided by [1, 2, 3] and [7, 8, 9]; row 1 by [4, 5, 6] and [10, 11, 12].
      updates = Tensor(np.arange(1, 13).reshape((2, 2, 3)).astype(np.float32))
      lock = True
      net = TestScatterDivNet(lock, inputx, indices, updates)
      net()
      expected = np.array([[2., 3., 4.],
                           [5., 6., 7.]])
      np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected)
  @pytest.mark.level0
  @pytest.mark.platform_x86_cpu
  @pytest.mark.env_onecard
  def test_scatter_div_output_updated_float32():
      """ScatterDiv (use_locking=True) returns the accumulated quotients.

      inputx is deliberately non-zero. An all-zero input makes every quotient
      zero, so a no-op kernel would pass. Each entry is (product of its two
      divisors) * (wanted value), which keeps all quotients exact in float32.
      """
      inputx = Tensor(np.array([[14., 48., 108.],
                                [200., 330., 504.]]).astype(np.float32))
      indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32))
      # Row 0 is divided by [1, 2, 3] and [7, 8, 9]; row 1 by [4, 5, 6] and [10, 11, 12].
      updates = Tensor(np.arange(1, 13).reshape((2, 2, 3)).astype(np.float32))
      output = scatter_div_net(inputx, indices, updates)
      expected = np.array([[2., 3., 4.],
                           [5., 6., 7.]])
      np.testing.assert_array_almost_equal(output.asnumpy(), expected)
  @pytest.mark.level0
  @pytest.mark.platform_x86_cpu
  @pytest.mark.env_onecard
  def test_scatter_div_small_float32_use_locking_false():
      """ScatterDiv with use_locking=False.

      The divisors are not all ones, because dividing by one lets a no-op
      kernel pass. They are drawn from {1, 2, 4, 5} so every result is exact
      in float32, whatever order the two divisions run in.
      """
      inputx = Tensor(np.ones((2, 3)).astype(np.float32) * 10)
      indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32))
      # updates[0, 0] and updates[1, 0] divide row 0; updates[0, 1] and updates[1, 1] divide row 1.
      updates = Tensor(np.array([[[2., 2., 2.], [5., 5., 5.]],
                                 [[1., 2., 5.], [1., 2., 4.]]]).astype(np.float32))
      output = scatter_div_use_locking_false_net(inputx, indices, updates)
      # Row 0 = 10 / 2 / [1, 2, 5]; row 1 = 10 / 5 / [1, 2, 4].
      expected = np.array([[5., 2.5, 1.],
                           [2., 1., 0.5]])
      np.testing.assert_array_almost_equal(output.asnumpy(), expected)
  class TestScatterMaxNet(nn.Cell):
      """Cell that applies ``P.ScatterMax`` to operands held as Parameters.

      Args:
          lock: forwarded to ``P.ScatterMax(use_locking=lock)``.
          inputx: initial value of the ``inputx`` Parameter (the tensor being
              scattered into; the *_input_updated test reads it after a call).
          indices: value of the ``indices`` Parameter.
          updates: value of the ``updates`` Parameter.
      """

      # The "Test" prefix matches pytest's default class-collection pattern, but
      # this is a helper Cell with an __init__, not a test suite; opting out
      # avoids a PytestCollectionWarning on every run.
      __test__ = False

      def __init__(self, lock, inputx, indices, updates):
          super(TestScatterMaxNet, self).__init__()
          self.scatter_max = P.ScatterMax(use_locking=lock)
          self.inputx = Parameter(inputx, name="inputx")
          self.indices = Parameter(indices, name="indices")
          self.updates = Parameter(updates, name="updates")

      def construct(self):
          # Return the operator's output tensor.
          out = self.scatter_max(self.inputx, self.indices, self.updates)
          return out
  def scatter_max_net(inputx, indices, updates):
      """Build a ScatterMax Cell with use_locking=True, run it once, return the output."""
      return TestScatterMaxNet(True, inputx, indices, updates)()
  def scatter_max_use_locking_false_net(inputx, indices, updates):
      """Build a ScatterMax Cell with use_locking=False, run it once, return the output."""
      return TestScatterMaxNet(False, inputx, indices, updates)()
  @pytest.mark.level0
  @pytest.mark.platform_x86_cpu
  @pytest.mark.env_onecard
  def test_scatter_max_input_updated():
      """ScatterMax leaves the element-wise maximum in the inputx Parameter."""
      # The later slices ([6, 7, 8] and [9, 10, 11]) dominate both zeros and earlier slices.
      want = np.array([[6., 7., 8.],
                       [9., 10., 11.]])
      net = TestScatterMaxNet(True,
                              Tensor(np.zeros((2, 3), dtype=np.float32)),
                              Tensor(np.array([[0, 1], [0, 1]], dtype=np.int32)),
                              Tensor(np.arange(12, dtype=np.float32).reshape(2, 2, 3)))
      net()
      np.testing.assert_array_almost_equal(net.inputx.asnumpy(), want)
  @pytest.mark.level0
  @pytest.mark.platform_x86_cpu
  @pytest.mark.env_onecard
  def test_scatter_max_output_updated_float32():
      """ScatterMax (use_locking=True) returns the element-wise maximum."""
      input_x = Tensor(np.zeros((2, 3), dtype=np.float32))
      idx = Tensor(np.array([[0, 1], [0, 1]], dtype=np.int32))
      upd = Tensor(np.arange(12, dtype=np.float32).reshape(2, 2, 3))
      result = scatter_max_net(input_x, idx, upd)
      want = np.array([[6., 7., 8.],
                       [9., 10., 11.]])
      np.testing.assert_array_almost_equal(result.asnumpy(), want)
  @pytest.mark.level0
  @pytest.mark.platform_x86_cpu
  @pytest.mark.env_onecard
  def test_scatter_max_small_float32_use_locking_false():
      """ScatterMax with use_locking=False; only the update value 11 exceeds the input 10."""
      input_x = Tensor(np.ones((2, 3), dtype=np.float32) * 10)
      idx = Tensor(np.array([[0, 1], [0, 1]], dtype=np.int32))
      upd = Tensor(np.arange(12, dtype=np.float32).reshape(2, 2, 3))
      result = scatter_max_use_locking_false_net(input_x, idx, upd)
      want = np.array([[10., 10., 10.],
                       [10., 10., 11.]])
      np.testing.assert_array_almost_equal(result.asnumpy(), want)
  class TestScatterMinNet(nn.Cell):
      """Cell that applies ``P.ScatterMin`` to operands held as Parameters.

      Args:
          lock: forwarded to ``P.ScatterMin(use_locking=lock)``.
          inputx: initial value of the ``inputx`` Parameter (the tensor being
              scattered into; the *_input_updated test reads it after a call).
          indices: value of the ``indices`` Parameter.
          updates: value of the ``updates`` Parameter.
      """

      # The "Test" prefix matches pytest's default class-collection pattern, but
      # this is a helper Cell with an __init__, not a test suite; opting out
      # avoids a PytestCollectionWarning on every run.
      __test__ = False

      def __init__(self, lock, inputx, indices, updates):
          super(TestScatterMinNet, self).__init__()
          self.scatter_min = P.ScatterMin(use_locking=lock)
          self.inputx = Parameter(inputx, name="inputx")
          self.indices = Parameter(indices, name="indices")
          self.updates = Parameter(updates, name="updates")

      def construct(self):
          # Return the operator's output tensor.
          out = self.scatter_min(self.inputx, self.indices, self.updates)
          return out
  def scatter_min_net(inputx, indices, updates):
      """Build a ScatterMin Cell with use_locking=True, run it once, return the output."""
      return TestScatterMinNet(True, inputx, indices, updates)()
  def scatter_min_use_locking_false_net(inputx, indices, updates):
      """Build a ScatterMin Cell with use_locking=False, run it once, return the output."""
      return TestScatterMinNet(False, inputx, indices, updates)()
  @pytest.mark.level0
  @pytest.mark.platform_x86_cpu
  @pytest.mark.env_onecard
  def test_scatter_min_input_updated():
      """ScatterMin leaves the element-wise minimum in the inputx Parameter.

      inputx starts at 5 rather than 0. With an all-zero input and non-negative
      updates, the minimum always equals the original value, so a kernel that
      never wrote back would still pass.
      """
      inputx = Tensor(np.ones((2, 3)).astype(np.float32) * 5)
      indices = Tensor(np.array([[0, 1], [0, 1]]).astype(np.int32))
      updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32))
      lock = True
      net = TestScatterMinNet(lock, inputx, indices, updates)
      net()
      # Row 0 = min(5, [0, 1, 2], [6, 7, 8]); row 1 = min(5, [3, 4, 5], [9, 10, 11]).
      expected = np.array([[0., 1., 2.],
                           [3., 4., 5.]])
      np.testing.assert_array_almost_equal(net.inputx.asnumpy(), expected)
  @pytest.mark.level0
  @pytest.mark.platform_x86_cpu
  @pytest.mark.env_onecard
  def test_scatter_min_output_updated_float32():
      """ScatterMin (use_locking=True) returns the element-wise minimum."""
      input_x = Tensor(np.ones((2, 3), dtype=np.float32))
      idx = Tensor(np.array([[0, 1], [0, 1]], dtype=np.int32))
      upd = Tensor(np.arange(12, dtype=np.float32).reshape(2, 2, 3))
      result = scatter_min_net(input_x, idx, upd)
      # Only the update value 0 is below the input value 1.
      want = np.array([[0., 1., 1.],
                       [1., 1., 1.]])
      np.testing.assert_array_almost_equal(result.asnumpy(), want)
  @pytest.mark.level0
  @pytest.mark.platform_x86_cpu
  @pytest.mark.env_onecard
  def test_scatter_min_small_float32_use_locking_false():
      """ScatterMin with use_locking=False returns the same minimum as the locked case."""
      input_x = Tensor(np.ones((2, 3), dtype=np.float32))
      idx = Tensor(np.array([[0, 1], [0, 1]], dtype=np.int32))
      upd = Tensor(np.arange(12, dtype=np.float32).reshape(2, 2, 3))
      result = scatter_min_use_locking_false_net(input_x, idx, upd)
      want = np.array([[0., 1., 1.],
                       [1., 1., 1.]])
      np.testing.assert_array_almost_equal(result.asnumpy(), want)
  class TestScatterUpdateNet(nn.Cell):
      """Cell that applies ``P.ScatterUpdate`` to operands held as Parameters.

      Args:
          lock: forwarded to ``P.ScatterUpdate(use_locking=lock)``.
          inputx: initial value of the ``inputx`` Parameter (the tensor being
              scattered into; the *_input_updated test reads it after a call).
          indices: value of the ``indices`` Parameter.
          updates: value of the ``updates`` Parameter.
      """

      # The "Test" prefix matches pytest's default class-collection pattern, but
      # this is a helper Cell with an __init__, not a test suite; opting out
      # avoids a PytestCollectionWarning on every run.
      __test__ = False

      def __init__(self, lock, inputx, indices, updates):
          super(TestScatterUpdateNet, self).__init__()
          self.scatter_update = P.ScatterUpdate(use_locking=lock)
          self.inputx = Parameter(inputx, name="inputx")
          self.indices = Parameter(indices, name="indices")
          self.updates = Parameter(updates, name="updates")

      def construct(self):
          # Return the operator's output tensor.
          out = self.scatter_update(self.inputx, self.indices, self.updates)
          return out
  def scatter_update_net(inputx, indices, updates):
      """Build a ScatterUpdate Cell with use_locking=True, run it once, return the output."""
      return TestScatterUpdateNet(True, inputx, indices, updates)()
  def scatter_update_use_locking_false_net(inputx, indices, updates):
      """Build a ScatterUpdate Cell with use_locking=False, run it once, return the output."""
      return TestScatterUpdateNet(False, inputx, indices, updates)()
  @pytest.mark.level0
  @pytest.mark.platform_x86_cpu
  @pytest.mark.env_onecard
  def test_scatter_update_input_updated():
      """ScatterUpdate overwrites rows of the inputx Parameter."""
      # idx repeats rows 0 and 1; the test expects the later slices
      # ([6, 7, 8] and [9, 10, 11]) to be the ones that remain.
      want = np.array([[6., 7., 8.],
                       [9., 10., 11.]])
      net = TestScatterUpdateNet(True,
                                 Tensor(np.zeros((2, 3), dtype=np.float32)),
                                 Tensor(np.array([[0, 1], [0, 1]], dtype=np.int32)),
                                 Tensor(np.arange(12, dtype=np.float32).reshape(2, 2, 3)))
      net()
      np.testing.assert_array_almost_equal(net.inputx.asnumpy(), want)
  @pytest.mark.level0
  @pytest.mark.platform_x86_cpu
  @pytest.mark.env_onecard
  def test_scatter_update_output_updated_float32():
      """ScatterUpdate (use_locking=True) returns the overwritten tensor."""
      input_x = Tensor(np.ones((2, 3), dtype=np.float32))
      idx = Tensor(np.array([[0, 1], [0, 1]], dtype=np.int32))
      upd = Tensor(np.arange(12, dtype=np.float32).reshape(2, 2, 3))
      result = scatter_update_net(input_x, idx, upd)
      want = np.array([[6., 7., 8.],
                       [9., 10., 11.]])
      np.testing.assert_array_almost_equal(result.asnumpy(), want)
  @pytest.mark.level0
  @pytest.mark.platform_x86_cpu
  @pytest.mark.env_onecard
  def test_scatter_update_small_float32_use_locking_false():
      """ScatterUpdate with use_locking=False returns the same rows as the locked case."""
      input_x = Tensor(np.ones((2, 3), dtype=np.float32))
      idx = Tensor(np.array([[0, 1], [0, 1]], dtype=np.int32))
      upd = Tensor(np.arange(12, dtype=np.float32).reshape(2, 2, 3))
      result = scatter_update_use_locking_false_net(input_x, idx, upd)
      want = np.array([[6., 7., 8.],
                       [9., 10., 11.]])
      np.testing.assert_array_almost_equal(result.asnumpy(), want)
  532. updates = Tensor(np.arange(12).reshape((2, 2, 3)).astype(np.float32))
  533. output = scatter_update_use_locking_false_net(inputx, indices, updates)
  534. expected = np.array([[6., 7., 8.],
  535. [9., 10., 11.]])
  536. np.testing.assert_array_almost_equal(output.asnumpy(), expected)