You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_tensor_index.py 37 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test_tensor_slice """
  16. import numpy as np
  17. import pytest
  18. from mindspore import Tensor, Parameter
  19. from mindspore import context
  20. from mindspore import dtype as mstype
  21. from mindspore.nn import Cell
  22. from mindspore.common.parameter import ParameterTuple
  23. from mindspore.ops import composite as C
  24. def setup_module():
  25. context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
  26. class NetWorkSlicePositive(Cell):
  27. def __init__(self):
  28. super(NetWorkSlicePositive, self).__init__()
  29. self.tensor_ret0 = Tensor(np.ones([1, 2, 3], np.int32))
  30. self.tensor_ret1 = Tensor(np.ones([4, 8, 10], np.int32))
  31. self.tensor_ret2 = Tensor(np.ones([6, 8, 10], np.int32))
  32. self.tensor_ret3 = Tensor(np.ones([3, 8, 10], np.int32))
  33. def construct(self, tensor):
  34. ret0 = tensor[3:4:1, 1:5:2, 3:6:1] + self.tensor_ret0
  35. ret1 = tensor[-6:4:1, 0:8:1, ::1] + self.tensor_ret1
  36. ret2 = tensor[::, ::, ::] + self.tensor_ret2
  37. ret3 = tensor[::2] + self.tensor_ret3
  38. return ret0, ret1, ret2, ret3
  39. def test_slice_positive():
  40. net = NetWorkSlicePositive()
  41. input_np = np.arange(6*8*10).reshape(6, 8, 10).astype(np.int32)
  42. input_0 = Tensor(input_np)
  43. output0, output1, output2, output3 = net(input_0)
  44. assert np.all(output0.asnumpy() == input_np[3:4:1, 1:5:2, 3:6:1] + np.ones([1, 2, 3]))
  45. assert np.all(output1.asnumpy() == input_np[-6:4:1, 0:8:1, ::1] + np.ones([4, 8, 10]))
  46. assert np.all(output2.asnumpy() == input_np[::, ::, ::] + np.ones([6, 8, 10]))
  47. assert np.all(output3.asnumpy() == input_np[::2] + np.ones([3, 8, 10]))
  48. class NetWorkSliceEllipsis(Cell):
  49. def __init__(self):
  50. super(NetWorkSliceEllipsis, self).__init__()
  51. self.tensor_ret0 = Tensor(np.ones([2, 7, 8], np.int32))
  52. self.tensor_ret1 = Tensor(np.ones([6, 7, 8, 9], np.int32))
  53. self.tensor_ret2 = Tensor(np.ones([1, 6, 7, 8, 9], np.int32))
  54. def construct(self, tensor):
  55. ret0 = tensor[0:4:2, ..., 1] + self.tensor_ret0
  56. ret1 = tensor[...] + self.tensor_ret1
  57. ret2 = tensor[None] + self.tensor_ret2
  58. ret3 = tensor[True] + self.tensor_ret2
  59. return ret0, ret1, ret2, ret3
  60. def Xtest_slice_ellipsis():
  61. net = NetWorkSliceEllipsis()
  62. input_np = np.arange(6*7*8*9).reshape(6, 7, 8, 9).astype(np.int32)
  63. input_0 = Tensor(input_np)
  64. output0, output1, output2, output3 = net(input_0)
  65. assert np.all(output0.asnumpy() == input_np[0:4:2, ..., 1] + np.ones([1, 2, 3]))
  66. assert np.all(output1.asnumpy() == input_np[...] + np.ones([6, 7, 8, 9]))
  67. assert np.all(output2.asnumpy() == input_np[None] + np.ones([6, 7, 8, 9]))
  68. assert np.all(output3.asnumpy() == input_np[True] + np.ones([1, 6, 7, 8, 9]))
  69. class NetWorkReduceDimension(Cell):
  70. def __init__(self):
  71. super(NetWorkReduceDimension, self).__init__()
  72. self.tensor_ret1 = Tensor(np.ones([3, 10], np.int32))
  73. self.tensor_ret2 = Tensor(np.ones([6, 8], np.int32))
  74. self.tensor_ret3 = Tensor(np.array(8, np.int32))
  75. self.tensor_ret4 = Tensor(np.ones([8, 10], np.int32))
  76. def construct(self, tensor):
  77. ret1 = tensor[::2, 1, ::1] + self.tensor_ret1
  78. ret2 = tensor[::, ::, 0] + self.tensor_ret2
  79. ret3 = tensor[3, 2, 5] + self.tensor_ret3
  80. ret4 = tensor[1] + self.tensor_ret4
  81. return ret1, ret2, ret3, ret4
  82. def Xtest_reduce_dimension():
  83. net = NetWorkReduceDimension()
  84. input_np = np.arange(6*8*10).reshape(6, 8, 10).astype(np.int32)
  85. input_0 = Tensor(input_np)
  86. output1, output2, output3, output4 = net(input_0)
  87. assert np.all(output1.asnumpy() == input_np[::2, 1, ::1] + np.ones([3, 10]))
  88. assert np.all(output2.asnumpy() == input_np[::, ::, 0] + np.ones([6, 8]))
  89. assert np.all(output3.asnumpy() == input_np[3, 2, 5] + np.array(8, np.int32))
  90. assert np.all(output4.asnumpy() == input_np[1] + np.ones([8, 10]))
  91. class NetWorkSliceStep(Cell):
  92. def __init__(self):
  93. super(NetWorkSliceStep, self).__init__()
  94. self.tensor_ret1 = Tensor(np.ones([6, 5, 10], np.int32))
  95. self.tensor_ret2 = Tensor(np.ones([3, 5, 5], np.int32))
  96. def construct(self, tensor):
  97. ret1 = tensor[::1, -5::, ::-1] + self.tensor_ret1
  98. ret2 = tensor[::2, -5::, ::2] + self.tensor_ret2
  99. return ret1, ret2
  100. def Xtest_step_negative():
  101. net = NetWorkSliceEllipsis()
  102. input_np = np.arange(6*8*10).reshape(6, 8, 10).astype(np.int32)
  103. input_0 = Tensor(input_np)
  104. output1, output2 = net(input_0)
  105. assert np.all(output1.asnumpy() == input_np[::1, -5::, ::-1] + np.ones([6, 8, 10]))
  106. assert np.all(output2.asnumpy() == input_np[::2, -5::, ::2] + np.ones([3, 5, 5]))
  107. class TensorGetItemByThreeTensors(Cell):
  108. def __init__(self):
  109. super(TensorGetItemByThreeTensors, self).__init__()
  110. self.const0 = Tensor(np.ones((4, 5, 8, 10)), mstype.int32)
  111. self.const1 = Tensor(np.ones((3, 4, 5, 10)), mstype.int32)
  112. self.const2 = Tensor(np.ones((5, 3, 4, 5)), mstype.int32)
  113. def construct(self, x, index_0, index_1, index_2):
  114. ret0 = x[index_0] + self.const0
  115. ret1 = x[index_0, index_1] + self.const1
  116. ret2 = x[index_0, index_1, index_2] + self.const2
  117. return ret0, ret1, ret2
  118. def test_getitem_by_tensors():
  119. net = TensorGetItemByThreeTensors()
  120. input_x = np.arange(6*8*10).reshape(6, 8, 10).astype(np.int32)
  121. index_0 = np.random.randint(6, size=(3, 4, 5)).astype(np.int32)
  122. index_1 = np.random.randint(6, size=(4, 5)).astype(np.int32)
  123. index_2 = np.random.randint(6, size=(5, 3, 4, 5)).astype(np.int32)
  124. input_x_ms = Tensor(input_x)
  125. index_0_ms = Tensor(index_0)
  126. index_1_ms = Tensor(index_1)
  127. input_2_ms = Tensor(index_2)
  128. output0, output1, output2 = net(input_x_ms, index_0_ms, index_1_ms, input_2_ms)
  129. assert np.all(output0.asnumpy() == input_x[index_0] + np.ones([4, 5, 8, 10]))
  130. assert np.all(output1.asnumpy() == input_x[index_0, index_1] + np.ones([3, 4, 5, 10]))
  131. assert np.all(output2.asnumpy() == input_x[index_0, index_1, index_2] + np.ones([5, 3, 4, 5]))
  132. class TensorGetItemByMixedTensorsBasicCase(Cell):
  133. def __init__(self, c0, c1, c2, c3, c4, c5):
  134. super(TensorGetItemByMixedTensorsBasicCase, self).__init__()
  135. self.const0 = Tensor(c0)
  136. self.const1 = Tensor(c1)
  137. self.const2 = Tensor(c2)
  138. self.const3 = Tensor(c3)
  139. self.const4 = Tensor(c4)
  140. self.const5 = Tensor(c5)
  141. def construct(self, tensor, index_0, index_1):
  142. ret0 = tensor[index_0, index_1, 0:3] + self.const0
  143. ret1 = tensor[0:3, index_0, ...] + self.const1
  144. ret2 = tensor[0, index_0, index_1] + self.const2
  145. ret3 = tensor[..., index_0, 0:3] + self.const3
  146. ret4 = tensor[0:2, index_0, index_1] + self.const4
  147. ret5 = tensor[..., index_0, index_1] + self.const5
  148. return ret0, ret1, ret2, ret3, ret4, ret5
  149. def test_getitem_by_mixed_tensors():
  150. const0 = np.ones((3, 4, 5, 3), np.float32)
  151. const1 = np.ones((3, 3, 4, 5, 5), np.float32)
  152. const2 = np.ones((3, 4, 5), np.float32)
  153. const3 = np.ones((3, 3, 4, 5, 3), np.float32)
  154. const4 = np.ones((2, 3, 4, 5), np.float32)
  155. const5 = np.ones((3, 3, 4, 5), np.float32)
  156. net = TensorGetItemByMixedTensorsBasicCase(const0, const1, const2, const3, const4, const5)
  157. input_np = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32)
  158. input_ms = Tensor(input_np, mstype.float32)
  159. index_np_0 = np.random.randint(3, size=(3, 4, 5)).astype(np.int32)
  160. index_np_1 = np.random.randint(4, size=(4, 5)).astype(np.int32)
  161. index_0 = Tensor(index_np_0, mstype.int32)
  162. index_1 = Tensor(index_np_1, mstype.int32)
  163. out0, out1, out2, out3, out4, out5 = net(input_ms, index_0, index_1)
  164. assert np.all(out0.asnumpy() == (input_np[index_np_0, index_np_1, 0:3] + const0))
  165. assert np.all(out1.asnumpy() == (input_np[0:3, index_np_0, ...] + const1))
  166. assert np.all(out2.asnumpy() == (input_np[0, index_np_0, index_np_1] + const2))
  167. assert np.all(out3.asnumpy() == (input_np[..., index_np_0, 0:3] + const3))
  168. assert np.all(out4.asnumpy() == (input_np[0:2, index_np_0, index_np_1] + const4))
  169. assert np.all(out5.asnumpy() == (input_np[..., index_np_0, index_np_1] + const5))
  170. class TensorSetItemByMixedTensors_0(Cell):
  171. def __init__(self, value):
  172. super(TensorSetItemByMixedTensors_0, self).__init__()
  173. self.const = Tensor(np.ones((3, 4, 5), np.float32))
  174. self.param = Parameter(Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)),
  175. mstype.float32),
  176. name="x")
  177. self.value = value
  178. def construct(self, index_0, index_1, index_2):
  179. self.param[0:2, index_0, index_1] = self.value
  180. ret = self.param + self.const
  181. return ret
  182. def test_setitem_by_mixed_tensors_0():
  183. value = 88.0
  184. net = TensorSetItemByMixedTensors_0(value)
  185. index_0 = np.random.randint(3, size=(3, 4, 5))
  186. index_1 = np.random.randint(4, size=(4, 5))
  187. index_2 = np.random.randint(3, size=(2, 1, 4, 5))
  188. index_0_ms = Tensor(index_0, mstype.int32)
  189. index_1_ms = Tensor(index_1, mstype.int32)
  190. index_2_ms = Tensor(index_2, mstype.int32)
  191. input_np = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32)
  192. const = np.ones((3, 4, 5), np.float32)
  193. out = net(index_0_ms, index_1_ms, index_2_ms)
  194. input_np[0:2, index_0, index_1] = value
  195. assert np.all(out.asnumpy() == (input_np + const))
  196. class TensorSetItemByMixedTensors_1(Cell):
  197. def __init__(self, value):
  198. super(TensorSetItemByMixedTensors_1, self).__init__()
  199. self.const = Tensor(np.ones((3, 4, 5), np.float32))
  200. self.param = Parameter(Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float32),
  201. name="x")
  202. self.value = value
  203. def construct(self, index_0, index_1, index_2):
  204. self.param[0:2, index_0, ...] = self.value
  205. ret = self.param + self.const
  206. return ret
  207. def test_setitem_by_mixed_tensors_1():
  208. value = 88.0
  209. net = TensorSetItemByMixedTensors_1(value)
  210. index_0 = np.random.randint(3, size=(3, 4, 5))
  211. index_1 = np.random.randint(4, size=(4, 5))
  212. index_2 = np.random.randint(3, size=(2, 1, 4, 5))
  213. index_0_ms = Tensor(index_0, mstype.int32)
  214. index_1_ms = Tensor(index_1, mstype.int32)
  215. index_2_ms = Tensor(index_2, mstype.int32)
  216. input_np = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32)
  217. const = np.ones((3, 4, 5), np.float32)
  218. out = net(index_0_ms, index_1_ms, index_2_ms)
  219. input_np[0:2, index_0, ...] = value
  220. assert np.all(out.asnumpy() == (input_np + const))
  221. class TensorSetItemByMixedTensors_2(Cell):
  222. def __init__(self, value):
  223. super(TensorSetItemByMixedTensors_2, self).__init__()
  224. self.const = Tensor(np.ones((3, 4, 5), np.float16))
  225. self.param = Parameter(Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)), mstype.float16),
  226. name="x")
  227. self.value = value
  228. def construct(self, index_0, index_1, index_2):
  229. self.param[..., index_0, 1] = self.value
  230. ret = self.param + self.const
  231. return ret
  232. def test_setitem_by_mixed_tensors_2():
  233. value = 88.0
  234. net = TensorSetItemByMixedTensors_2(value)
  235. index_0 = np.random.randint(3, size=(3, 4, 5))
  236. index_1 = np.random.randint(4, size=(4, 5))
  237. index_2 = np.random.randint(3, size=(2, 1, 4, 5))
  238. index_0_ms = Tensor(index_0, mstype.int32)
  239. index_1_ms = Tensor(index_1, mstype.int32)
  240. index_2_ms = Tensor(index_2, mstype.int32)
  241. input_np = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32)
  242. const = np.ones((3, 4, 5), np.float32)
  243. out = net(index_0_ms, index_1_ms, index_2_ms)
  244. input_np[..., index_0, 1] = value
  245. assert np.all(out.asnumpy() == (input_np + const))
  246. class TensorGetItemByMixedTensorsTypeError(Cell):
  247. def __init__(self):
  248. super(TensorGetItemByMixedTensorsTypeError, self).__init__()
  249. def construct(self, x, index_0, index_1):
  250. ret = x[index_0, index_1, 0:3, ..., 0:5, [1, 2, 3, 4]]
  251. return ret
  252. def test_getitem_by_mixedtensor_exception():
  253. input_ms = Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32)
  254. index_0 = Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32)
  255. index_1 = Tensor(np.random.randint(4, size=(3, 4, 5)), mstype.int32)
  256. net1 = TensorGetItemByMixedTensorsTypeError()
  257. with pytest.raises(TypeError):
  258. net1(input_ms, index_0, index_1)
  259. class TensorSetItemByOneTensorWithNumber(Cell):
  260. def __init__(self, value):
  261. super(TensorSetItemByOneTensorWithNumber, self).__init__()
  262. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  263. self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
  264. self.value = value
  265. def construct(self, index):
  266. self.param[index] = self.value
  267. ret = self.param + self.const
  268. return ret
  269. def test_setitem_one_tensor_with_number():
  270. value = 0.0
  271. net = TensorSetItemByOneTensorWithNumber(value)
  272. index_np = np.random.randint(4, size=(5, 4))
  273. index = Tensor(index_np, mstype.int32)
  274. input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8))
  275. const = np.ones((6, 7, 8)).astype(np.float32)
  276. out = net(index)
  277. input_data[index_np] = value
  278. assert np.all(out.asnumpy() == (input_data + const))
  279. class TensorSetItemByOneTensorWithTensor(Cell):
  280. def __init__(self):
  281. super(TensorSetItemByOneTensorWithTensor, self).__init__()
  282. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  283. self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
  284. def construct(self, index, value):
  285. self.param[index] = value
  286. ret = self.param + self.const
  287. return ret
  288. def test_setitem_by_one_tensor_with_tensor():
  289. net = TensorSetItemByOneTensorWithTensor()
  290. index_np = np.random.randint(4, size=(5, 4))
  291. index = Tensor(index_np, mstype.int32)
  292. input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8))
  293. const = np.ones((6, 7, 8)).astype(np.float32)
  294. value = np.zeros((4, 7, 8)).astype(np.float32)
  295. value_ms = Tensor(value, mstype.float32)
  296. out = net(index, value_ms)
  297. input_data[index_np] = value
  298. assert np.all(out.asnumpy() == (input_data + const))
  299. class TensorSetItemByOneTensorWithTupleOfNumber(Cell):
  300. def __init__(self, value):
  301. super(TensorSetItemByOneTensorWithTupleOfNumber, self).__init__()
  302. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  303. self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
  304. self.value = value
  305. def construct(self, index):
  306. self.param[index] = self.value
  307. ret = self.param + self.const
  308. return ret
  309. def test_setitem_by_one_tensor_with_tuple_number():
  310. value = (0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7)
  311. net = TensorSetItemByOneTensorWithTupleOfNumber(value)
  312. input_np = np.random.randint(5, size=(5, 4))
  313. input_ms = Tensor(input_np, mstype.int32)
  314. input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8)).astype(np.float32)
  315. const = np.ones((6, 7, 8)).astype(np.float32)
  316. out = net(input_ms)
  317. input_data[input_np] = value
  318. assert np.all(out.asnumpy() == (input_data + const))
  319. class TensorSetItemByOneTensorWithTupleOfTensor(Cell):
  320. def __init__(self):
  321. super(TensorSetItemByOneTensorWithTupleOfTensor, self).__init__()
  322. self.const = Tensor(np.ones((6, 3, 8)), mstype.float32)
  323. self.param = Parameter(Tensor(np.arange(6 * 3 * 8).reshape((6, 3, 8)), mstype.float32), name="x")
  324. def construct(self, index, value_0, value_1, value_2):
  325. self.param[index] = (value_0, value_1, value_2)
  326. ret = self.param + self.const
  327. return ret
  328. def test_setitem_by_one_tensor_with_tuple_tensors():
  329. net = TensorSetItemByOneTensorWithTupleOfTensor()
  330. input_np = np.random.randint(6, size=(5, 4)).astype(np.int32)
  331. input_ms = Tensor(input_np, mstype.int32)
  332. input_data = np.arange(6 * 3 * 8).reshape((6, 3, 8)).astype(np.float32)
  333. value_0_np = np.zeros((8,), np.float32)
  334. value_1_np = np.ones((8,), np.float32)
  335. value_2_np = np.ones((8,), np.float32)*2
  336. value_0 = Tensor(value_0_np)
  337. value_1 = Tensor(value_1_np)
  338. value_2 = Tensor(value_2_np)
  339. const = np.ones((6, 3, 8)).astype(np.float32)
  340. out = net(input_ms, value_0, value_1, value_2)
  341. input_data[input_np] = (value_0_np, value_1_np, value_2_np)
  342. assert np.all(out.asnumpy() == (input_data + const))
  343. class TensorSetItemByTensorsWithNumber(Cell):
  344. def __init__(self, value):
  345. super(TensorSetItemByTensorsWithNumber, self).__init__()
  346. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  347. self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
  348. self.value = value
  349. def construct(self, index_0, index_1, index_2):
  350. self.param[index_0, index_1, index_2] = self.value
  351. ret = self.param + self.const
  352. return ret
  353. def test_setitem_by_tensors_with_number():
  354. value = 0.0
  355. net = TensorSetItemByTensorsWithNumber(value)
  356. index_0 = np.random.randint(6, size=(3, 4, 5))
  357. index_1 = np.random.randint(7, size=(4, 5))
  358. index_2 = np.random.randint(8, size=(5, 3, 4, 5))
  359. index_0_ms = Tensor(index_0, mstype.int32)
  360. index_1_ms = Tensor(index_1, mstype.int32)
  361. index_2_ms = Tensor(index_2, mstype.int32)
  362. out = net(index_0_ms, index_1_ms, index_2_ms)
  363. const = np.ones((6, 7, 8)).astype(np.float32)
  364. input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8)).astype(np.float32)
  365. input_data[index_0, index_1, index_2] = value
  366. assert np.all(out.asnumpy() == (input_data + const))
  367. class TensorSetItemByTensorsWithTensor(Cell):
  368. def __init__(self):
  369. super(TensorSetItemByTensorsWithTensor, self).__init__()
  370. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  371. self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
  372. def construct(self, index_0, index_1, index_2, value):
  373. self.param[index_0, index_1, index_2] = value
  374. ret = self.param + self.const
  375. return ret
  376. def test_setitem_by_tensors_with_tensor():
  377. net = TensorSetItemByTensorsWithTensor()
  378. index_0 = np.random.randint(6, size=(3, 4, 5))
  379. index_1 = np.random.randint(7, size=(4, 5))
  380. index_2 = np.random.randint(8, size=(5, 3, 4, 5))
  381. value = np.zeros((4, 5)).astype(np.float32)
  382. index_0_ms = Tensor(index_0, mstype.int32)
  383. index_1_ms = Tensor(index_1, mstype.int32)
  384. index_2_ms = Tensor(index_2, mstype.int32)
  385. value_ms = Tensor(value, mstype.float32)
  386. out = net(index_0_ms, index_1_ms, index_2_ms, value_ms)
  387. const = np.ones((6, 7, 8)).astype(np.float32)
  388. input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8)).astype(np.float32)
  389. input_data[index_0, index_1, index_2] = value
  390. assert np.all(out.asnumpy() == (input_data + const))
  391. class TensorSetItemByTensorsWithTensorNumberError(Cell):
  392. def __init__(self):
  393. super(TensorSetItemByTensorsWithTensorNumberError, self).__init__()
  394. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  395. self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
  396. def construct(self, index_0, index_1, index_2, index_3, value):
  397. self.param[index_0, index_1, index_2, index_3] = value
  398. ret = self.param + self.const
  399. return ret
  400. def test_setitem_by_tensors_with_tensor_error():
  401. index_0 = Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32)
  402. index_1 = Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)
  403. index_2 = Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)
  404. index_3 = Tensor(np.random.randint(8, size=(1, 3, 4, 5)), mstype.int32)
  405. value = Tensor(np.zeros((2, 5)), mstype.float32)
  406. net = TensorSetItemByTensorsWithTensorNumberError()
  407. with pytest.raises(IndexError):
  408. net(index_0, index_1, index_2, index_3, value)
  409. class TensorSetItemByTensorsWithTupleOfNumber(Cell):
  410. def __init__(self, value):
  411. super(TensorSetItemByTensorsWithTupleOfNumber, self).__init__()
  412. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  413. self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
  414. self.value = value
  415. def construct(self, index_0, index_1, index_2):
  416. self.param[index_0, index_1, index_2] = self.value
  417. ret = self.param + self.const
  418. return ret
  419. def test_setitem_by_tensors_with_tuple_of_number():
  420. value = (0.0, 1.1, 2.2, 3.3, 4.4)
  421. net = TensorSetItemByTensorsWithTupleOfNumber(value)
  422. index_0 = np.random.randint(6, size=(3, 4, 5))
  423. index_1 = np.random.randint(7, size=(4, 5))
  424. index_2 = np.random.randint(8, size=(5, 3, 4, 5))
  425. index_0_ms = Tensor(index_0, mstype.int32)
  426. index_1_ms = Tensor(index_1, mstype.int32)
  427. index_2_ms = Tensor(index_2, mstype.int32)
  428. input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8)).astype(np.float32)
  429. input_data[index_0, index_1, index_2] = value
  430. const = np.ones((6, 7, 8)).astype(np.float32)
  431. out = net(index_0_ms, index_1_ms, index_2_ms)
  432. assert np.all(out.asnumpy() == (input_data + const))
  433. class TensorSetItemByTensorsWithTupleOfTensor(Cell):
  434. def __init__(self):
  435. super(TensorSetItemByTensorsWithTupleOfTensor, self).__init__()
  436. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  437. self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
  438. def construct(self, index_0, index_1, index_2, value_0, value_1, value_2):
  439. self.param[index_0, index_1, index_2] = (value_0, value_1, value_2)
  440. ret = self.param + self.const
  441. return ret
  442. def test_setitem_by_tensors_with_tuple_of_tensor():
  443. value_0 = np.zeros((4, 5))
  444. value_1 = np.ones((4, 5))
  445. value_2 = np.ones((4, 5)) * 2
  446. value_0_ms = Tensor(value_0, mstype.float32)
  447. value_1_ms = Tensor(value_1, mstype.float32)
  448. value_2_ms = Tensor(value_2, mstype.float32)
  449. net = TensorSetItemByTensorsWithTupleOfTensor()
  450. index_0 = np.random.randint(6, size=(3, 4, 5))
  451. index_1 = np.random.randint(7, size=(4, 5))
  452. index_2 = np.random.randint(8, size=(5, 3, 4, 5))
  453. index_0_ms = Tensor(index_0, mstype.int32)
  454. index_1_ms = Tensor(index_1, mstype.int32)
  455. index_2_ms = Tensor(index_2, mstype.int32)
  456. input_data = np.arange(6 * 7 * 8).reshape((6, 7, 8)).astype(np.float32)
  457. input_data[index_0, index_1, index_2] = (value_0, value_1, value_2)
  458. const = np.ones((6, 7, 8)).astype(np.float32)
  459. out = net(index_0_ms, index_1_ms, index_2_ms, value_0_ms, value_1_ms, value_2_ms)
  460. assert np.all(out.asnumpy() == (input_data + const))
  461. class TensorSetItemByTensorsWithTupleOfTensorNumberError(Cell):
  462. def __init__(self):
  463. super(TensorSetItemByTensorsWithTupleOfTensorNumberError, self).__init__()
  464. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  465. self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
  466. def construct(self, index_0, index_1, index_2, value_0, value_1):
  467. self.param[index_0, index_1, index_2] = (value_0, value_1)
  468. ret = self.param + self.const
  469. return ret
  470. def test_setitem_by_tensor_with_tuple_of_tensor_error():
  471. net = TensorSetItemByTensorsWithTupleOfTensorNumberError()
  472. index_0_ms = Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32)
  473. index_1_ms = Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)
  474. index_2_ms = Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)
  475. value_0 = np.zeros((4, 5))
  476. value_1 = np.ones((4, 5))
  477. value_0_ms = Tensor(value_0, mstype.float32)
  478. value_1_ms = Tensor(value_1, mstype.float32)
  479. with pytest.raises(ValueError):
  480. net(index_0_ms, index_1_ms, index_2_ms, value_0_ms, value_1_ms)
  481. def test_setitem_grad():
  482. class Net(Cell):
  483. def __init__(self):
  484. super(Net, self).__init__()
  485. self.weight = Parameter(
  486. Tensor(np.ones([4, 4, 5]), dtype=mstype.float32), "b1", requires_grad=True)
  487. def construct(self, a, b):
  488. a[1:3:1, ::] = b
  489. c = a + self.weight
  490. return c
  491. class GradNet(Cell):
  492. def __init__(self, net):
  493. super(GradNet, self).__init__()
  494. self.net = net
  495. self.weights = ParameterTuple(net.trainable_params())
  496. def construct(self, x, y, sens):
  497. return C.grad_by_list_with_sens(self.net, self.weights)(x, y, sens)
  498. net = GradNet(Net())
  499. x = Tensor(np.ones([4, 4, 5]).astype(np.float32), mstype.float32)
  500. y = Tensor(np.array([3]).astype(np.float32), mstype.float32)
  501. sens = Tensor(np.ones([4, 4, 5]).astype(np.float32), mstype.float32)
  502. net(x, y, sens)
  503. class TensorAssignWithSliceError1(Cell):
  504. def __init__(self):
  505. super(TensorAssignWithSliceError1, self).__init__()
  506. def construct(self, a, b):
  507. a[1:3:-1, ::] = b
  508. return a
  509. class TensorAssignWithSliceError2(Cell):
  510. def __init__(self):
  511. super(TensorAssignWithSliceError2, self).__init__()
  512. def construct(self, a, b):
  513. a[1:3:-1] = b
  514. return a
  515. class TensorAssignWithSlice2(Cell):
  516. def __init__(self):
  517. super(TensorAssignWithSlice2, self).__init__()
  518. def construct(self, a, b, ck):
  519. a[1:5] = b
  520. a[3:4] = 5
  521. a[-1:1:-1] = b
  522. a[-1:3:-1] = 5
  523. a[::] = b
  524. a[::] = 9
  525. z = a + ck
  526. return z
  527. class TensorAssignWithSlice(Cell):
  528. def __init__(self):
  529. super(TensorAssignWithSlice, self).__init__()
  530. self.c = 2.0
  531. def construct(self, a, b, ck):
  532. a[1:3, ::] = b
  533. a[2:3:, 3:] = b
  534. a[::] = b
  535. a[::] = self.c
  536. a[::, ::] = b
  537. a[::, ::] = self.c
  538. a[2:3:, 0:, 4:1:-1] = b
  539. a[2:3:, 0:, 4:1:-1] = self.c
  540. z = a + ck
  541. return z
  542. def test_tensor_assign_slice_value_1():
  543. net = TensorAssignWithSlice()
  544. a = np.arange(60).reshape(3, 4, 5)
  545. ck = np.arange(60).reshape(3, 4, 5)
  546. b = np.array([1]).astype(np.float32) # Tensor([1], dtype=mstype.float32)
  547. tb = Tensor(b, dtype=mstype.float32)
  548. ta = Tensor(a, dtype=mstype.float32)
  549. tck = Tensor(ck, dtype=mstype.float32)
  550. out = net(ta, tb, tck)
  551. a[1:3, ::] = b
  552. a[2:3:, 3:] = b
  553. a[::] = b
  554. a[::] = 2.0
  555. a[::, ::] = b
  556. a[::, ::] = 2.0
  557. a[2:3:, 0:, 4:1:-1] = b
  558. a[2:3:, 0:, 4:1:-1] = 2.0
  559. z = a + ck
  560. assert np.all(z == out.asnumpy())
  561. def test_tensor_assign_slice_value_2():
  562. net2 = TensorAssignWithSlice2()
  563. a = np.array([1, 2, 3, 4, 5, 6, 7, 8])
  564. ck = np.array([1, 2, 3, 4, 5, 6, 7, 8])
  565. b = np.array([1]).astype(np.float32) # Tensor([1], dtype=mstype.float32)
  566. tb = Tensor(b, dtype=mstype.float32)
  567. ta = Tensor(a, dtype=mstype.float32)
  568. tck = Tensor(ck, dtype=mstype.float32)
  569. out = net2(ta, tb, tck)
  570. a[1:5] = b
  571. a[3:4] = 5
  572. a[-1:1:-1] = b
  573. a[-1:3:-1] = 5
  574. a[::] = b
  575. a[::] = 9
  576. z = a + ck
  577. assert np.all(z == out.asnumpy())
  578. def test_tensor_assign_exception():
  579. net = TensorAssignWithSlice()
  580. net2 = TensorAssignWithSlice2()
  581. net_e1 = TensorAssignWithSliceError1()
  582. net_e2 = TensorAssignWithSliceError2()
  583. a = np.arange(60).reshape(3, 4, 5)
  584. ck = np.arange(60).reshape(3, 4, 5)
  585. b = Tensor([1], dtype=mstype.float32)
  586. Ta = Tensor(a, dtype=mstype.float32)
  587. Tck = Tensor(ck, dtype=mstype.float32)
  588. Ta4d = Tensor(a.reshape(1, 3, 4, 5), dtype=mstype.float32)
  589. Ta4d_ck = Tensor(ck.reshape(1, 3, 4, 5), dtype=mstype.float32)
  590. Tb = Tensor([1, 3], dtype=mstype.float32)
  591. Tc = Tensor([], dtype=mstype.float32)
  592. t = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
  593. tck = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
  594. # Error for A[Slice] = Number
  595. # 1. A[Slice] = Number, Slice error
  596. with pytest.raises(IndexError):
  597. net_e2(t, 2)
  598. # Error for A[Slice] = U, U is a Tensor
  599. # 1. A[Slice] = U, u.size is error
  600. with pytest.raises(ValueError):
  601. net2(t, Tb, tck)
  602. # 2. A[Slice] = U, U is empty
  603. with pytest.raises(ValueError):
  604. net2(t, Tc, tck)
  605. # 3. A[Slice] = U, U.size error
  606. with pytest.raises(ValueError):
  607. net2(t, Tb, tck)
  608. # Error for A[Tuple(Slice...)] = Tensor
  609. # 1. A[Tuple(Slice...)] = U, U is empty
  610. with pytest.raises(ValueError):
  611. net(Ta, Tc, Tck)
  612. # 2. A[Tuple(Slice...)] = U, U.size error
  613. with pytest.raises(ValueError):
  614. net(Ta, Tb, Tck)
  615. # 3. A[Tuple(Slice...)] = U, Slice error
  616. with pytest.raises(IndexError):
  617. net_e1(Ta, b)
  618. # Error for A[Tuple(Slice...)] = Number
  619. # 1. A[Tuple(Slice...)] = Number, Slice error
  620. with pytest.raises(IndexError):
  621. net_e1(Ta, 2)
  622. net = TensorAssignWithInteger()
  623. # Error for A[Number] = scalar/Tensor
  624. # 1. A[Number] = U, U is a Tensor, u.size not match
  625. with pytest.raises(ValueError):
  626. net(Ta, Tb, Tck)
  627. with pytest.raises(ValueError):
  628. net(Ta, Tc, Tck)
  629. # 2. A[Number] = U, the number index error
  630. with pytest.raises(IndexError):
  631. net(Ta4d, b, Ta4d_ck)
  632. # Error for A[(n,m)] = scalar/Tensor
  633. # 1. A[(n,m)] = U, U is a tensor. u.size not match
  634. net = TensorAssignWithTupleInteger()
  635. with pytest.raises(ValueError):
  636. net(Ta, Tc, Tck)
  637. with pytest.raises(ValueError):
  638. net(Ta, Tb, Tck)
  639. # 2. A[(n,m)] = U, the number index error
  640. with pytest.raises(IndexError):
  641. net(Ta4d, b, Ta4d_ck)
  642. # Error for A[...] = U or A[1:, ...] = u
  643. # 1. A[...] = scalar/tensor
  644. net = TensorAssignWithEllipsis()
  645. net(Ta, Ta4d)
  646. with pytest.raises(ValueError):
  647. net(Ta, Tc)
  648. with pytest.raises(ValueError):
  649. net(Ta, Tb)
  650. # 2. A[::, 1:, ...] = scalar/tensor
  651. net = TensorAssignWithTupleEllipsis()
  652. net(Ta, b)
  653. with pytest.raises(ValueError):
  654. net(Ta, Tb)
  655. class TensorAssignWithTupleEllipsis2(Cell):
  656. def __init__(self):
  657. super(TensorAssignWithTupleEllipsis2, self).__init__()
  658. def construct(self, a, b):
  659. a[1:, ..., ::] = b
  660. return a
  661. class TensorAssignWithTupleEllipsis(Cell):
  662. def __init__(self):
  663. super(TensorAssignWithTupleEllipsis, self).__init__()
  664. def construct(self, a, b):
  665. a[:2, ...] = 1.0
  666. a[1:, ...] = b
  667. return a
  668. class TensorAssignWithEllipsis(Cell):
  669. def __init__(self):
  670. super(TensorAssignWithEllipsis, self).__init__()
  671. def construct(self, a, b):
  672. a[...] = 1
  673. a[...] = b
  674. return a
  675. class TensorAssignWithInteger(Cell):
  676. def __init__(self):
  677. super(TensorAssignWithInteger, self).__init__()
  678. def construct(self, a, b, ck):
  679. a[1] = 1
  680. a[0] = b
  681. z = a + ck
  682. return z
  683. class TensorAssignWithTupleInteger(Cell):
  684. def __init__(self):
  685. super(TensorAssignWithTupleInteger, self).__init__()
  686. def construct(self, a, b, ck):
  687. a[(1)] = 1
  688. a[(1)] = b
  689. a[(1, 1)] = b
  690. a[(1, 1)] = 1
  691. z = a + ck
  692. return z
  693. class TensorAssignWithBoolTensorIndex(Cell):
  694. def __init__(self):
  695. super(TensorAssignWithBoolTensorIndex, self).__init__()
  696. self.t = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
  697. self.u_scalar = 5
  698. def construct(self, a, b, c, u_tensor):
  699. a[c] = self.u_scalar
  700. a[b] = u_tensor
  701. z = a + self.t
  702. return z
  703. class TensorAssignWithBoolTensorIndexError(Cell):
  704. def __init__(self):
  705. super(TensorAssignWithBoolTensorIndexError, self).__init__()
  706. def construct(self, a, b, c, u_tensor):
  707. a[b][c] = u_tensor
  708. return a
  709. class TensorAssignWithBoolTensorIndex2(Cell):
  710. def __init__(self):
  711. super(TensorAssignWithBoolTensorIndex2, self).__init__()
  712. self.t = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
  713. self.u_scalar = 5
  714. def construct(self, a, u_tensor):
  715. a[a > 8] = u_tensor
  716. a[a >= 6] = self.u_scalar
  717. a[a < 3] = self.u_scalar
  718. a[a <= 5] = u_tensor
  719. a[a == 5] = self.u_scalar
  720. z = a + self.t
  721. return z
  722. class TensorAssignWithBoolTensorIndex2Error(Cell):
  723. def __init__(self):
  724. super(TensorAssignWithBoolTensorIndex2Error, self).__init__()
  725. def construct(self, a, u_tensor):
  726. a[a > 8][a > 5] = u_tensor
  727. return a
  728. def test_tensor_assign_bool_index_0():
  729. a = np.arange(60).reshape(3, 4, 5)
  730. b = a > 5
  731. c = a < 3
  732. Ta = Tensor(a, dtype=mstype.float32)
  733. Tb = Tensor(b)
  734. Tc = Tensor(c)
  735. u_tensor = Tensor([1], dtype=mstype.float32)
  736. net1 = TensorAssignWithBoolTensorIndex()
  737. out = net1(Ta, Tb, Tc, u_tensor)
  738. res = np.arange(60).reshape(3, 4, 5)
  739. res[c] = 5
  740. res[b] = 1
  741. res = res + np.ones([3, 4, 5])
  742. assert np.all(out.asnumpy() == res)
  743. def test_tensor_assign_bool_index_1():
  744. a = np.arange(60).reshape(3, 4, 5)
  745. Ta = Tensor(a, dtype=mstype.float32)
  746. u_tensor = Tensor([1], dtype=mstype.float32)
  747. net2 = TensorAssignWithBoolTensorIndex2()
  748. out = net2(Ta, u_tensor)
  749. res = np.arange(60).reshape(3, 4, 5)
  750. res[res > 8] = 1
  751. res[res >= 6] = 5
  752. res[res < 3] = 5
  753. res[res <= 5] = 1
  754. res[res == 5] = 5
  755. res = res + np.ones([3, 4, 5])
  756. assert np.all(out.asnumpy() == res)
  757. def test_tensor_assign_bool_index_exception():
  758. a = np.arange(60).reshape(3, 4, 5)
  759. b = a > 5
  760. c = a < 3
  761. Ta = Tensor(a, dtype=mstype.float32)
  762. Tb = Tensor(b)
  763. Tc = Tensor(c)
  764. Td = Tensor([True, True])
  765. u_tensor = Tensor([1], dtype=mstype.float32)
  766. u_tensor_error = Tensor([1, 2], dtype=mstype.float32)
  767. u_scalar = 5
  768. net1 = TensorAssignWithBoolTensorIndex()
  769. net2 = TensorAssignWithBoolTensorIndex2()
  770. with pytest.raises(ValueError):
  771. net1(Ta, Td, Tc, u_tensor)
  772. with pytest.raises(IndexError):
  773. net1(Ta, u_tensor, Tc, u_tensor)
  774. with pytest.raises(ValueError):
  775. net1(Ta, Tb, Td, u_tensor)
  776. with pytest.raises(IndexError):
  777. net1(Ta, Tb, Ta, u_tensor)
  778. with pytest.raises(ValueError):
  779. net1(Ta, Tb, Tc, u_tensor_error)
  780. # net1(Ta, u_tensor, Tc, u_tensor_error, u_scalar)
  781. with pytest.raises(ValueError):
  782. net2(Ta, u_tensor_error)
  783. net3 = TensorAssignWithBoolTensorIndexError()
  784. with pytest.raises(IndexError):
  785. net3(Ta, Tb, Tc, u_tensor)
  786. with pytest.raises(IndexError):
  787. net3(Ta, Tb, Tc, u_scalar)
  788. net4 = TensorAssignWithBoolTensorIndex2Error()
  789. with pytest.raises(IndexError):
  790. net4(Ta, u_tensor)
  791. with pytest.raises(IndexError):
  792. net4(Ta, u_scalar)
  793. def Xtest_tensor_slice_reduce_out_of_bounds_neg():
  794. class NetWork(Cell):
  795. def __init__(self):
  796. super(NetWork, self).__init__()
  797. self.tensor_ret = Tensor(np.array(9, np.int32))
  798. def construct(self, tensor):
  799. ret = tensor[-7, 3, 4]
  800. return ret
  801. input_tensor = Tensor(np.ones([6, 8, 10], np.int32))
  802. net = NetWork()
  803. with pytest.raises(ValueError) as ex:
  804. net(input_tensor)
  805. assert "For 'StridedSlice' the `begin[0]` should be an int and must greater or equal to -6, but got `-7`" in str(
  806. ex.value)
  807. def Xtest_tensor_slice_reduce_out_of_bounds_positive():
  808. class NetWork(Cell):
  809. def __init__(self):
  810. super(NetWork, self).__init__()
  811. self.tensor_ret = Tensor(np.array(9, np.int32))
  812. def construct(self, tensor):
  813. ret = tensor[6, 3, 4]
  814. return ret
  815. input_tensor = Tensor(np.ones([6, 8, 10], np.int32))
  816. net = NetWork()
  817. with pytest.raises(ValueError) as ex:
  818. net(input_tensor)
  819. assert "For 'StridedSlice' the `begin[0]` should be an int and must less than 6, but got `6`" in str(ex.value)
  820. def test_tensor_range():
  821. a = np.arange(4*5*6).reshape(4, 5, 6).astype(np.float32)
  822. ta = Tensor(a, mstype.float32)
  823. ms_out = []
  824. for item in ta:
  825. ms_out.append(item)
  826. np_out = []
  827. for item in a:
  828. np_out.append(item)
  829. for i, elem in enumerate(ms_out):
  830. assert np.all(elem.asnumpy() == np_out[i])