You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_tensor_setitem.py 6.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test_tensor_setitem """
  16. import numpy as onp
  17. import pytest
  18. from mindspore import Tensor, context
  19. from mindspore.nn import Cell
  20. def setup_module():
  21. context.set_context(mode=context.GRAPH_MODE)
  22. def setup_testcase(input_np, case_fn):
  23. input_ms = Tensor(input_np)
  24. class TensorSetItem(Cell):
  25. def construct(self, x):
  26. return case_fn(x)
  27. class NumpySetItem():
  28. def __call__(self, x):
  29. return case_fn(x)
  30. out_ms = TensorSetItem()(input_ms)
  31. out_np = NumpySetItem()(input_np)
  32. assert onp.all(out_ms.asnumpy() == out_np)
class TensorSetItemByList(Cell):
    """MindSpore Cell that applies three list-index assignments to its input.

    The statements in ``construct`` are identical, in the same order, to those
    in ``NumpySetItemByList.__call__``. Later assignments may overwrite
    elements written by earlier ones, so the order is significant.
    """
    def construct(self, x):
        # One list per axis; the lists pair up element-wise
        # (presumably x is rank-3, like the (2, 3, 4) arrays used in this file).
        x[[0, 1], [1, 2], [1, 3]] = [3, 4]
        # The same kind of index, written as an explicit tuple of lists.
        x[([0, 1], [0, 2], [1, 1])] = [10, 5]
        # Lists on the first and last axes with an Ellipsis spanning the rest.
        x[[0, 1], ..., [0, 1]] = 4
        return x
  39. class NumpySetItemByList():
  40. def __call__(self, x):
  41. x[[0, 1], [1, 2], [1, 3]] = [3, 4]
  42. x[([0, 1], [0, 2], [1, 1])] = [10, 5]
  43. x[[0, 1], ..., [0, 1]] = 4
  44. return x
  45. @pytest.mark.level0
  46. @pytest.mark.platform_arm_ascend_training
  47. @pytest.mark.platform_x86_ascend_training
  48. @pytest.mark.platform_x86_gpu_training
  49. @pytest.mark.env_onecard
  50. def test_setitem_by_list():
  51. x = onp.ones((2, 3, 4), dtype=onp.float32)
  52. def cases(x):
  53. x[[0, 1], [1, 2], [1, 3]] = [3, 4]
  54. x[([0, 1], [0, 2], [1, 1])] = [10, 5]
  55. x[[0, 1], ..., [0, 1]] = 4
  56. return x
  57. setup_testcase(x, cases)
  58. @pytest.mark.level0
  59. @pytest.mark.platform_arm_ascend_training
  60. @pytest.mark.platform_x86_ascend_training
  61. @pytest.mark.platform_x86_gpu_training
  62. @pytest.mark.env_onecard
  63. def test_setitem_with_sequence():
  64. x = onp.ones((2, 3, 4), dtype=onp.float32)
  65. def cases(x):
  66. x[...] = [3]
  67. x[..., 1] = ([1, 2, 3], [4, 5, 6])
  68. x[0] = ((0, 1, 2, 3), (4, 5, 6, 7), [8, 9, 10, 11])
  69. x[1:2] = ((0, 1, 2, 3), (4, 5, 6, 7), [8, 9, 10, 11])
  70. return x
  71. setup_testcase(x, cases)
  72. @pytest.mark.level0
  73. @pytest.mark.platform_arm_ascend_training
  74. @pytest.mark.platform_x86_ascend_training
  75. @pytest.mark.platform_x86_gpu_training
  76. @pytest.mark.env_onecard
  77. def test_setitem_dtype():
  78. x = onp.ones((2, 3, 4), dtype=onp.float32)
  79. def cases(x):
  80. x[...] = 3
  81. x[..., 1] = 3.0
  82. x[0] = True
  83. x[1:2] = ((0, False, 2, 3), (4.0, 5, 6, 7), [True, 9, 10, 11])
  84. return x
  85. setup_testcase(x, cases)
  86. @pytest.mark.level0
  87. @pytest.mark.platform_arm_ascend_training
  88. @pytest.mark.platform_x86_ascend_training
  89. @pytest.mark.platform_x86_gpu_training
  90. @pytest.mark.env_onecard
  91. def test_setitem_by_tuple_with_int():
  92. x = onp.arange(24).reshape(2, 3, 4).astype(onp.float32)
  93. def cases(x):
  94. x[..., 2, False, 1] = -1
  95. x[0, True, 0, None, True] = -2
  96. x[0, ..., None] = -3
  97. x[..., 0, None, 1, True, True, None] = -4
  98. return x
  99. setup_testcase(x, cases)
  100. @pytest.mark.level0
  101. @pytest.mark.platform_arm_ascend_training
  102. @pytest.mark.platform_x86_ascend_training
  103. @pytest.mark.platform_x86_gpu_training
  104. @pytest.mark.env_onecard
  105. def test_setitem_by_tuple_with_list():
  106. x = onp.arange(24).reshape(2, 3, 4).astype(onp.float32)
  107. def cases(x):
  108. x[..., 2, False, 1] = [-1]
  109. x[0, True, 0, None, True] = [-2, -2, -2, -2]
  110. x[0, ..., None] = [[-3], [-3], [-3], [-3]]
  111. x[..., 0, None, 1, True, True, None] = [[[-4]], [[-4]]]
  112. return x
  113. setup_testcase(x, cases)
  114. @pytest.mark.level0
  115. @pytest.mark.platform_arm_ascend_training
  116. @pytest.mark.platform_x86_ascend_training
  117. @pytest.mark.platform_x86_gpu_training
  118. @pytest.mark.env_onecard
  119. def test_setitem_by_nested_unit_list():
  120. x = onp.arange(24).reshape(2, 3, 4).astype(onp.float32)
  121. def cases(x):
  122. x[[[[0]]], True] = -1
  123. x[[1], ..., [[[[2]]]]] = -2
  124. x[0, [[[2]]], [1]] = -3
  125. return x
  126. setup_testcase(x, cases)
  127. @pytest.mark.level0
  128. @pytest.mark.platform_arm_ascend_training
  129. @pytest.mark.platform_x86_ascend_training
  130. @pytest.mark.platform_x86_gpu_training
  131. @pytest.mark.env_onecard
  132. def test_setitem_with_broadcast():
  133. x = onp.arange(2*3*4*5*6).reshape(2, 3, 4, 5, 6).astype(onp.float32)
  134. v1 = onp.full((1, 4, 5), -1).tolist()
  135. v2 = onp.full((4, 1, 6), -2).tolist()
  136. def cases(x):
  137. x[..., 4] = v1
  138. x[0, 2] = v2
  139. x[1, 0, ..., 3] = [[-3], [-3], [-3], [-3]]
  140. x[0, ..., 1, 3, 5] = -4
  141. return x
  142. setup_testcase(x, cases)
  143. @pytest.mark.level0
  144. @pytest.mark.platform_arm_ascend_training
  145. @pytest.mark.platform_x86_ascend_training
  146. @pytest.mark.platform_x86_gpu_training
  147. @pytest.mark.env_onecard
  148. def test_setitem_mul_by_scalar():
  149. x = onp.ones((4, 5), dtype=onp.float32)
  150. def cases(x):
  151. x[1, :] = x[1, :]*2
  152. x[:, 2] = x[:, 3]*3.0
  153. return x
  154. setup_testcase(x, cases)
  155. @pytest.mark.level0
  156. @pytest.mark.platform_arm_ascend_training
  157. @pytest.mark.platform_x86_ascend_training
  158. @pytest.mark.platform_x86_gpu_training
  159. @pytest.mark.env_onecard
  160. def test_setitem_by_slice():
  161. x = onp.ones((3, 4, 5), dtype=onp.float32)
  162. def cases(x):
  163. x[1:2] = 2
  164. x[-3:1] = 3
  165. x[-10:3:2] = 4
  166. x[5:0:3] = 5
  167. x[5:5:5] = 6
  168. x[-1:2] = 7
  169. return x
  170. setup_testcase(x, cases)
  171. @pytest.mark.level0
  172. @pytest.mark.platform_arm_ascend_training
  173. @pytest.mark.platform_x86_ascend_training
  174. @pytest.mark.platform_x86_gpu_training
  175. @pytest.mark.env_onecard
  176. def test_setitem_by_tuple_of_slices():
  177. x = onp.ones((3, 4, 5), dtype=onp.float32)
  178. def cases(x):
  179. x[1:2, 2] = 2
  180. x[0, -4:1] = 3
  181. x[1, -10:3:2] = 4
  182. x[5:0:3, 3] = 5
  183. x[1:1, 2:2] = 6
  184. return x
  185. setup_testcase(x, cases)