You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_tensor_setitem.py 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test_tensor_setitem """
  16. import numpy as onp
  17. import pytest
  18. from mindspore import Tensor, context
  19. from mindspore.nn import Cell
  20. from mindspore import dtype as mstype
  21. def setup_module():
  22. context.set_context(mode=context.PYNATIVE_MODE)
  23. def setup_testcase(input_np, case_fn):
  24. input_ms = Tensor(input_np)
  25. class TensorSetItem(Cell):
  26. def construct(self, x):
  27. return case_fn(x)
  28. class NumpySetItem():
  29. def __call__(self, x):
  30. return case_fn(x)
  31. out_ms = TensorSetItem()(input_ms)
  32. out_np = NumpySetItem()(input_np)
  33. assert onp.all(out_ms.asnumpy() == out_np)
class TensorSetItemByList(Cell):
    """MindSpore Cell applying the same list-index writes as NumpySetItemByList.

    NOTE(review): not referenced by any test in the visible part of this file;
    test_setitem_by_list uses an inline ``cases`` function instead.
    """
    def construct(self, x):
        # Three index lists combined element-wise (two target positions).
        x[[0, 1], [1, 2], [1, 3]] = [3, 4]
        # The same kind of indexing spelled as an explicit tuple of lists.
        x[([0, 1], [0, 2], [1, 1])] = [10, 5]
        # Index lists on the first and last axes, with "..." spanning the middle.
        x[[0, 1], ..., [0, 1]] = 4
        return x
  40. class NumpySetItemByList():
  41. def __call__(self, x):
  42. x[[0, 1], [1, 2], [1, 3]] = [3, 4]
  43. x[([0, 1], [0, 2], [1, 1])] = [10, 5]
  44. x[[0, 1], ..., [0, 1]] = 4
  45. return x
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_by_list():
    """Setitem with list indices: plain lists, a tuple of lists, and lists around "..."."""
    x = onp.ones((2, 3, 4), dtype=onp.float32)
    def cases(x):
        # Lists combine element-wise: in NumPy this writes x[0, 1, 1] and x[1, 2, 3].
        x[[0, 1], [1, 2], [1, 3]] = [3, 4]
        # Explicit tuple-of-lists spelling of the same indexing form.
        x[([0, 1], [0, 2], [1, 1])] = [10, 5]
        # Lists on the first and last axes, separated by an Ellipsis.
        x[[0, 1], ..., [0, 1]] = 4
        return x
    # Applies `cases` to a MindSpore Tensor and to the NumPy array, then compares.
    setup_testcase(x, cases)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_with_sequence():
    """Setitem where the value is a Python list/tuple (possibly nested and mixed)."""
    x = onp.ones((2, 3, 4), dtype=onp.float32)
    def cases(x):
        # A one-element list broadcast over the whole array.
        x[...] = [3]
        # Tuple of lists matching the (2, 3) shape of x[..., 1].
        x[..., 1] = ([1, 2, 3], [4, 5, 6])
        # Nested value mixing tuples and a list, shaped (3, 4).
        x[0] = ((0, 1, 2, 3), (4, 5, 6, 7), [8, 9, 10, 11])
        # Same (3, 4) value written into a (1, 3, 4) slice.
        x[1:2] = ((0, 1, 2, 3), (4, 5, 6, 7), [8, 9, 10, 11])
        return x
    setup_testcase(x, cases)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_dtype():
    """Setitem into a float32 array with int, float, bool and mixed-type sequence values."""
    x = onp.ones((2, 3, 4), dtype=onp.float32)
    def cases(x):
        x[...] = 3
        x[..., 1] = 3.0
        x[0] = True
        # One nested value mixing bool, int and float elements.
        x[1:2] = ((0, False, 2, 3), (4.0, 5, 6, 7), [True, 9, 10, 11])
        return x
    setup_testcase(x, cases)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_by_tuple_with_int():
    """Setitem with index tuples mixing ints, Ellipsis, None and bool scalars; scalar values.

    In NumPy, None and True each insert a length-1 axis without consuming an
    input axis. False inserts a length-0 axis, so the first write below
    selects no elements there.
    """
    x = onp.arange(24).reshape(2, 3, 4).astype(onp.float32)
    def cases(x):
        x[..., 2, False, 1] = -1
        # Ints fix axes 0 and 1: targets x[0, 0, :].
        x[0, True, 0, None, True] = -2
        x[0, ..., None] = -3
        # Ints fix axes 1 and 2 and "..." spans axis 0: targets x[:, 0, 1].
        x[..., 0, None, 1, True, True, None] = -4
        return x
    setup_testcase(x, cases)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_by_tuple_with_list():
    """Like test_setitem_by_tuple_with_int, but with list values shaped for broadcasting."""
    x = onp.arange(24).reshape(2, 3, 4).astype(onp.float32)
    def cases(x):
        x[..., 2, False, 1] = [-1]
        x[0, True, 0, None, True] = [-2, -2, -2, -2]
        x[0, ..., None] = [[-3], [-3], [-3], [-3]]
        x[..., 0, None, 1, True, True, None] = [[[-4]], [[-4]]]
        # Mixes None, True, integer lists and a tuple of bools in one index.
        x[None, True, [1, 0], (False, True, True), [2]] = [[2, 3]]
        return x
    setup_testcase(x, cases)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_by_nested_unit_list():
    """Setitem with deeply nested single-element lists used as indices."""
    x = onp.arange(24).reshape(2, 3, 4).astype(onp.float32)
    def cases(x):
        x[[[[0]]], True] = -1
        x[[1], ..., [[[[2]]]]] = -2
        x[0, [[[2]]], [1]] = -3
        return x
    setup_testcase(x, cases)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_with_broadcast():
    """Setitem on a 5-D array where the value must broadcast to the target shape."""
    x = onp.arange(2*3*4*5*6).reshape(2, 3, 4, 5, 6).astype(onp.float32)
    # Nested Python lists (not arrays) of shape (1, 4, 5) and (4, 1, 6).
    v1 = onp.full((1, 4, 5), -1).tolist()
    v2 = onp.full((4, 1, 6), -2).tolist()
    def cases(x):
        # Target shape (2, 3, 4, 5) <- value (1, 4, 5).
        x[..., 4] = v1
        # Target shape (4, 5, 6) <- value (4, 1, 6).
        x[0, 2] = v2
        # Target shape (4, 5) <- value (4, 1).
        x[1, 0, ..., 3] = [[-3], [-3], [-3], [-3]]
        # Target shape (3,) <- scalar.
        x[0, ..., 1, 3, 5] = -4
        return x
    setup_testcase(x, cases)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_mul_by_scalar():
    """Setitem whose value is computed from a slice of the same array times a scalar."""
    x = onp.ones((4, 5), dtype=onp.float32)
    def cases(x):
        x[1, :] = x[1, :]*2
        # Reads column 3, which includes x[1, 3] already doubled by the previous write.
        x[:, 2] = x[:, 3]*3.0
        return x
    setup_testcase(x, cases)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_by_slice():
    """Setitem by a single slice on axis 0 (length 3), including out-of-range and empty slices."""
    x = onp.ones((3, 4, 5), dtype=onp.float32)
    def cases(x):
        x[1:2] = 2
        # Negative start: equivalent to 0:1.
        x[-3:1] = 3
        # Start clipped to 0, step 2: rows 0 and 2.
        x[-10:3:2] = 4
        # The next three slices are empty on a length-3 axis and write nothing.
        x[5:0:3] = 5
        x[5:5:5] = 6
        x[-1:2] = 7
        # Negative step: selects row 1 only.
        x[1:0:-1] = 8
        return x
    setup_testcase(x, cases)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_by_tuple_of_slices():
    """Setitem by index tuples combining slices with ints, including empty selections."""
    x = onp.ones((3, 4, 5), dtype=onp.float32)
    def cases(x):
        x[1:2, 2] = 2
        x[0, -4:1] = 3
        x[1, -10:3:2] = 4
        # Both of the following select nothing (5:0:3 and 1:1 are empty ranges).
        x[5:0:3, 3] = 5
        x[1:1, 2:2] = 6
        return x
    setup_testcase(x, cases)
  189. class TensorItemSetWithNumber(Cell):
  190. def construct(self, tensor, number_value):
  191. ret = tensor.itemset(number_value)
  192. return ret
  193. @pytest.mark.level0
  194. @pytest.mark.platform_arm_ascend_training
  195. @pytest.mark.platform_x86_ascend_training
  196. @pytest.mark.platform_x86_gpu_training
  197. @pytest.mark.env_onecard
  198. def test_itemset_with_number():
  199. net = TensorItemSetWithNumber()
  200. input_1d_np = onp.ndarray([1]).astype(onp.float32)
  201. input_1d_ms = Tensor(input_1d_np, mstype.float32)
  202. input_3d_np = onp.arange(60).reshape(3, 4, 5).astype(onp.int32)
  203. input_3d_ms = Tensor(input_3d_np, mstype.float32)
  204. value_np_1, value_np_2 = 1, 2.0
  205. output_1d_ms_1 = net(input_1d_ms, value_np_1)
  206. output_1d_ms_2 = net(input_1d_ms, value_np_2)
  207. input_1d_np.itemset(value_np_1)
  208. assert onp.all(output_1d_ms_1.asnumpy() == input_1d_np)
  209. input_1d_np.itemset(value_np_2)
  210. assert onp.all(output_1d_ms_2.asnumpy() == input_1d_np)
  211. with pytest.raises(IndexError):
  212. net(input_3d_ms, value_np_1)
  213. with pytest.raises(IndexError):
  214. net(input_3d_ms, value_np_2)
  215. class TensorItemSetByItemWithNumber(Cell):
  216. def construct(self, tensor, index, number_value):
  217. ret = tensor.itemset(index, number_value)
  218. return ret
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_setitem_dim_expand():
    """Setitem with indices that add axes (None/True) alongside list, tuple and slice indices."""
    x = onp.ones((2, 3, 4), dtype=onp.float32)
    def cases(x):
        x[None, True, [1, 0], (False, True, True), [2]] = 2
        # The parentheses around [[0]] are grouping only; they do not create a tuple.
        x[([[0]]), ..., [[1]]] = [[[3, 3, 3]]]
        x[0:1] = [[2, 3, 4, 5]]
        x[..., (0, 1, 2), None, :, True, None] = [[[3], [3], [3], [3]]]
        return x
    setup_testcase(x, cases)
  233. @pytest.mark.level0
  234. @pytest.mark.platform_arm_ascend_training
  235. @pytest.mark.platform_x86_ascend_training
  236. @pytest.mark.platform_x86_gpu_training
  237. @pytest.mark.env_onecard
  238. def test_itemset_by_number_with_number():
  239. net = TensorItemSetByItemWithNumber()
  240. input_1d_np = onp.ndarray([1]).astype(onp.float32)
  241. input_1d_ms = Tensor(input_1d_np, mstype.float32)
  242. input_3d_np = onp.arange(60).reshape(3, 4, 5).astype(onp.int32)
  243. input_3d_ms = Tensor(input_3d_np, mstype.float32)
  244. index_np_1, index_np_2, index_np_3, index_np_4 = 0, 30, 60, 2.0
  245. value_np_1, value_np_2 = 1, 2.0
  246. output_1d_ms_1 = net(input_1d_ms, index_np_1, value_np_1)
  247. output_1d_ms_2 = net(input_1d_ms, index_np_1, value_np_2)
  248. output_3d_ms_1 = net(input_3d_ms, index_np_1, value_np_1)
  249. output_3d_ms_2 = net(output_3d_ms_1, index_np_1, value_np_2)
  250. output_3d_ms_3 = net(output_3d_ms_2, index_np_2, value_np_1)
  251. output_3d_ms_4 = net(output_3d_ms_3, index_np_2, value_np_2)
  252. input_1d_np.itemset(index_np_1, value_np_1)
  253. assert onp.all(output_1d_ms_1.asnumpy() == input_1d_np)
  254. input_1d_np.itemset(index_np_1, value_np_2)
  255. assert onp.all(output_1d_ms_2.asnumpy() == input_1d_np)
  256. input_3d_np.itemset(index_np_1, value_np_1)
  257. assert onp.all(output_3d_ms_1.asnumpy() == input_3d_np)
  258. input_3d_np.itemset(index_np_1, value_np_2)
  259. assert onp.all(output_3d_ms_2.asnumpy() == input_3d_np)
  260. input_3d_np.itemset(index_np_2, value_np_1)
  261. assert onp.all(output_3d_ms_3.asnumpy() == input_3d_np)
  262. input_3d_np.itemset(index_np_2, value_np_2)
  263. assert onp.all(output_3d_ms_4.asnumpy() == input_3d_np)
  264. with pytest.raises(IndexError):
  265. net(input_1d_ms, index_np_2, value_np_1)
  266. with pytest.raises(IndexError):
  267. net(input_1d_ms, index_np_2, value_np_2)
  268. with pytest.raises(TypeError):
  269. net(input_1d_ms, index_np_4, value_np_1)
  270. with pytest.raises(TypeError):
  271. net(input_1d_ms, index_np_4, value_np_2)
  272. with pytest.raises(IndexError):
  273. net(input_3d_ms, index_np_3, value_np_1)
  274. with pytest.raises(IndexError):
  275. net(input_3d_ms, index_np_3, value_np_2)
  276. with pytest.raises(TypeError):
  277. net(input_3d_ms, index_np_4, value_np_1)
  278. with pytest.raises(TypeError):
  279. net(input_3d_ms, index_np_4, value_np_2)
  280. @pytest.mark.level0
  281. @pytest.mark.platform_arm_ascend_training
  282. @pytest.mark.platform_x86_ascend_training
  283. @pytest.mark.platform_x86_gpu_training
  284. @pytest.mark.env_onecard
  285. def test_itemset_by_tuple_with_number():
  286. net = TensorItemSetByItemWithNumber()
  287. input_1d_np = onp.ndarray([1]).astype(onp.float32)
  288. input_1d_ms = Tensor(input_1d_np, mstype.float32)
  289. input_3d_np = onp.arange(60).reshape(3, 4, 5).astype(onp.int32)
  290. input_3d_ms = Tensor(input_3d_np, mstype.float32)
  291. index_np_1, index_np_2, index_np_3, index_np_4, index_np_5 = (0,), (1, 2), (1, 1, 0), (3, 4, 5), (1, 2, 3, 4)
  292. value_np_1, value_np_2 = 1, 2.0
  293. output_1d_ms_1 = net(input_1d_ms, index_np_1, value_np_1)
  294. input_1d_np.itemset(index_np_1, value_np_1)
  295. assert onp.all(output_1d_ms_1.asnumpy() == input_1d_np)
  296. output_1d_ms_2 = net(input_1d_ms, index_np_1, value_np_2)
  297. input_1d_np.itemset(index_np_1, value_np_2)
  298. assert onp.all(output_1d_ms_2.asnumpy() == input_1d_np)
  299. output_3d_ms_1 = net(input_3d_ms, index_np_3, value_np_1)
  300. input_3d_np.itemset(index_np_3, value_np_1)
  301. assert onp.all(output_3d_ms_1.asnumpy() == input_3d_np)
  302. output_3d_ms_2 = net(input_3d_ms, index_np_3, value_np_2)
  303. input_3d_np.itemset(index_np_3, value_np_2)
  304. assert onp.all(output_3d_ms_2.asnumpy() == input_3d_np)
  305. with pytest.raises(ValueError):
  306. net(input_1d_ms, index_np_2, value_np_1)
  307. with pytest.raises(ValueError):
  308. net(input_1d_ms, index_np_2, value_np_2)
  309. with pytest.raises(ValueError):
  310. net(input_3d_ms, index_np_1, value_np_1)
  311. with pytest.raises(ValueError):
  312. net(input_3d_ms, index_np_1, value_np_2)
  313. with pytest.raises(ValueError):
  314. net(input_3d_ms, index_np_2, value_np_1)
  315. with pytest.raises(ValueError):
  316. net(input_3d_ms, index_np_2, value_np_2)
  317. with pytest.raises(IndexError):
  318. net(input_3d_ms, index_np_4, value_np_1)
  319. with pytest.raises(IndexError):
  320. net(input_3d_ms, index_np_4, value_np_2)
  321. with pytest.raises(ValueError):
  322. net(input_3d_ms, index_np_5, value_np_1)
  323. with pytest.raises(ValueError):
  324. net(input_3d_ms, index_np_5, value_np_2)