You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

test_tensor_index.py 26 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test_tensor_slice """
  16. import numpy as np
  17. import pytest
  18. from mindspore import Tensor, Parameter
  19. from mindspore import context
  20. from mindspore import dtype as mstype
  21. from mindspore.nn import Cell
  22. def setup_module():
  23. context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
  24. class NetWorkSlicePositive(Cell):
  25. def __init__(self):
  26. super(NetWorkSlicePositive, self).__init__()
  27. self.tensor_ret0 = Tensor(np.ones([1, 2, 3], np.int32))
  28. self.tensor_ret1 = Tensor(np.ones([4, 8, 10], np.int32))
  29. self.tensor_ret2 = Tensor(np.ones([6, 8, 10], np.int32))
  30. self.tensor_ret3 = Tensor(np.ones([3, 8, 10], np.int32))
  31. def construct(self, tensor):
  32. ret0 = tensor[3:4:1, 1:5:2, 3:6:1] + self.tensor_ret0
  33. ret1 = tensor[-6:4:1, 0:8:1, ::1] + self.tensor_ret1
  34. ret2 = tensor[::, ::, ::] + self.tensor_ret2
  35. ret3 = tensor[::2] + self.tensor_ret3
  36. return ret0, ret1, ret2, ret3
  37. def test_slice_positive():
  38. net = NetWorkSlicePositive()
  39. input_np = np.arange(6*8*10).reshape(6, 8, 10).astype(np.int32)
  40. input_0 = Tensor(input_np)
  41. output0, output1, output2, output3 = net(input_0)
  42. assert np.all(output0.asnumpy() == input_np[3:4:1, 1:5:2, 3:6:1] + np.ones([1, 2, 3]))
  43. assert np.all(output1.asnumpy() == input_np[-6:4:1, 0:8:1, ::1] + np.ones([4, 8, 10]))
  44. assert np.all(output2.asnumpy() == input_np[::, ::, ::] + np.ones([6, 8, 10]))
  45. assert np.all(output3.asnumpy() == input_np[::2] + np.ones([3, 8, 10]))
  46. class NetWorkSliceEllipsis(Cell):
  47. def __init__(self):
  48. super(NetWorkSliceEllipsis, self).__init__()
  49. self.tensor_ret0 = Tensor(np.ones([2, 7, 8], np.int32))
  50. self.tensor_ret1 = Tensor(np.ones([6, 7, 8, 9], np.int32))
  51. self.tensor_ret2 = Tensor(np.ones([1, 6, 7, 8, 9], np.int32))
  52. def construct(self, tensor):
  53. ret0 = tensor[0:4:2, ..., 1] + self.tensor_ret0
  54. ret1 = tensor[...] + self.tensor_ret1
  55. ret2 = tensor[None] + self.tensor_ret2
  56. ret3 = tensor[True] + self.tensor_ret2
  57. return ret0, ret1, ret2, ret3
  58. def Xtest_slice_ellipsis():
  59. net = NetWorkSliceEllipsis()
  60. input_np = np.arange(6*7*8*9).reshape(6, 7, 8, 9).astype(np.int32)
  61. input_0 = Tensor(input_np)
  62. output0, output1, output2, output3 = net(input_0)
  63. assert np.all(output0.asnumpy() == input_np[0:4:2, ..., 1] + np.ones([1, 2, 3]))
  64. assert np.all(output1.asnumpy() == input_np[...] + np.ones([6, 7, 8, 9]))
  65. assert np.all(output2.asnumpy() == input_np[None] + np.ones([6, 7, 8, 9]))
  66. assert np.all(output3.asnumpy() == input_np[True] + np.ones([1, 6, 7, 8, 9]))
  67. class NetWorkReduceDimension(Cell):
  68. def __init__(self):
  69. super(NetWorkReduceDimension, self).__init__()
  70. self.tensor_ret1 = Tensor(np.ones([3, 10], np.int32))
  71. self.tensor_ret2 = Tensor(np.ones([6, 8], np.int32))
  72. self.tensor_ret3 = Tensor(np.array(8, np.int32))
  73. self.tensor_ret4 = Tensor(np.ones([8, 10], np.int32))
  74. def construct(self, tensor):
  75. ret1 = tensor[::2, 1, ::1] + self.tensor_ret1
  76. ret2 = tensor[::, ::, 0] + self.tensor_ret2
  77. ret3 = tensor[3, 2, 5] + self.tensor_ret3
  78. ret4 = tensor[1] + self.tensor_ret4
  79. return ret1, ret2, ret3, ret4
  80. def Xtest_reduce_dimension():
  81. net = NetWorkReduceDimension()
  82. input_np = np.arange(6*8*10).reshape(6, 8, 10).astype(np.int32)
  83. input_0 = Tensor(input_np)
  84. output1, output2, output3, output4 = net(input_0)
  85. assert np.all(output1.asnumpy() == input_np[::2, 1, ::1] + np.ones([3, 10]))
  86. assert np.all(output2.asnumpy() == input_np[::, ::, 0] + np.ones([6, 8]))
  87. assert np.all(output3.asnumpy() == input_np[3, 2, 5] + np.array(8, np.int32))
  88. assert np.all(output4.asnumpy() == input_np[1] + np.ones([8, 10]))
  89. class NetWorkSliceStep(Cell):
  90. def __init__(self):
  91. super(NetWorkSliceStep, self).__init__()
  92. self.tensor_ret1 = Tensor(np.ones([6, 5, 10], np.int32))
  93. self.tensor_ret2 = Tensor(np.ones([3, 5, 5], np.int32))
  94. def construct(self, tensor):
  95. ret1 = tensor[::1, -5::, ::-1] + self.tensor_ret1
  96. ret2 = tensor[::2, -5::, ::2] + self.tensor_ret2
  97. return ret1, ret2
  98. def Xtest_step_negative():
  99. net = NetWorkSliceEllipsis()
  100. input_np = np.arange(6*8*10).reshape(6, 8, 10).astype(np.int32)
  101. input_0 = Tensor(input_np)
  102. output1, output2 = net(input_0)
  103. assert np.all(output1.asnumpy() == input_np[::1, -5::, ::-1] + np.ones([6, 8, 10]))
  104. assert np.all(output2.asnumpy() == input_np[::2, -5::, ::2] + np.ones([3, 5, 5]))
  105. class TensorGetItemByThreeTensors(Cell):
  106. def __init__(self):
  107. super(TensorGetItemByThreeTensors, self).__init__()
  108. self.const0 = Tensor(np.ones((4, 5, 8, 10)), mstype.int32)
  109. self.const1 = Tensor(np.ones((3, 4, 5, 10)), mstype.int32)
  110. self.const2 = Tensor(np.ones((5, 3, 4, 5)), mstype.int32)
  111. def construct(self, x, index_0, index_1, index_2):
  112. ret0 = x[index_0] + self.const0
  113. ret1 = x[index_0, index_1] + self.const1
  114. ret2 = x[index_0, index_1, index_2] + self.const2
  115. return ret0, ret1, ret2
  116. def Xtest_getitem_by_tensors():
  117. net = TensorGetItemByThreeTensors()
  118. input_x = np.arange(6*8*10).reshape(6, 8, 10).astype(np.int32)
  119. index_0 = np.random.randint(6, size=(3, 4, 5)).astype(np.int32)
  120. index_1 = np.random.randint(6, size=(4, 5)).astype(np.int32)
  121. index_2 = np.random.randint(6, size=(5, 3, 4, 5)).astype(np.int32)
  122. input_x_ms = Tensor(input_x)
  123. index_0_ms = Tensor(index_0)
  124. index_1_ms = Tensor(index_1)
  125. input_2_ms = Tensor(index_2)
  126. output0, output1, output2 = net(input_x_ms, index_0_ms, index_1_ms, input_2_ms)
  127. assert np.all(output0.asnumpy() == input_x[index_0] + np.ones([4, 5, 8, 10]))
  128. assert np.all(output1.asnumpy() == input_x[index_0, index_1] + np.ones([3, 4, 5, 10]))
  129. assert np.all(output2.asnumpy() == input_x[index_0, index_1, index_2] + np.ones([5, 3, 4, 5]))
  130. class TensorGetItemByMixedTensors_0(Cell):
  131. def __init__(self):
  132. super(TensorGetItemByMixedTensors_0, self).__init__()
  133. self.const = Tensor(np.ones((3, 4, 5, 3, 6, 5), np.float32))
  134. def construct(self, tensor, index_0, index_1):
  135. ret = tensor[index_0, index_1, 0:3, ..., 0:5, 3] + self.const
  136. return ret
  137. class TensorGetItemByMixedTensors_1(Cell):
  138. def __init__(self):
  139. super(TensorGetItemByMixedTensors_1, self).__init__()
  140. self.const = Tensor(np.ones((3, 4, 5, 3, 5, 5), np.float32))
  141. def construct(self, tensor, index_0, index_1):
  142. ret = tensor[0:3, index_0, ..., index_1, 3, 0:5] + self.const
  143. return ret
  144. class TensorGetItemByMixedTensors_2(Cell):
  145. def __init__(self):
  146. super(TensorGetItemByMixedTensors_2, self).__init__()
  147. self.const = Tensor(np.ones((3, 4, 5, 6, 7), np.float32))
  148. def construct(self, tensor, index_0, index_1):
  149. ret = tensor[0, index_0, index_1, ..., 3] + self.const
  150. return ret
  151. class TensorGetItemByMixedTensors_3(Cell):
  152. def __init__(self):
  153. super(TensorGetItemByMixedTensors_3, self).__init__()
  154. self.const = Tensor(np.ones((3, 4, 5, 3, 4, 3, 5), np.float32))
  155. def construct(self, tensor, index_0, index_1):
  156. ret = tensor[..., index_0, 0:3, index_1, 0:5] + self.const
  157. return ret
  158. class TensorGetItemByMixedTensors_4(Cell):
  159. def __init__(self):
  160. super(TensorGetItemByMixedTensors_4, self).__init__()
  161. self.const = Tensor(np.ones((2, 2, 3, 4, 5, 3, 9), np.float32))
  162. def construct(self, tensor, index_0, index_1, index_2):
  163. ret = tensor[0:2, index_0, index_1, 2, index_2, 0:3, ...] + self.const
  164. return ret
  165. class TensorGetItemByMixedTensors_5(Cell):
  166. def __init__(self):
  167. super(TensorGetItemByMixedTensors_5, self).__init__()
  168. self.const = Tensor(np.ones((2, 3, 4, 5, 2, 6), np.float32))
  169. def construct(self, tensor, index_0, index_1, index_2):
  170. ret = tensor[0:2, index_0, index_1, ..., index_2, 2] + self.const
  171. return ret
  172. class TensorGetItemByMixedTensors_6(Cell):
  173. def __init__(self):
  174. super(TensorGetItemByMixedTensors_6, self).__init__()
  175. self.const = Tensor(np.ones((3, 4, 2, 3, 4, 5), np.float32))
  176. def construct(self, tensor, index_0, index_1, index_2):
  177. ret = tensor[..., index_0, index_1, index_2, 3] + self.const
  178. return ret
  179. class TensorSetItemByMixedTensors_0(Cell):
  180. def __init__(self, value):
  181. super(TensorSetItemByMixedTensors_0, self).__init__()
  182. self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8, 9), np.float32))
  183. self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)),
  184. mstype.float32),
  185. name="x")
  186. self.value = value
  187. def construct(self, index_0, index_1, index_2):
  188. self.param[0:2, index_0, index_1, 2, index_2, 0:3, ...] = self.value
  189. ret = self.param + self.const
  190. return ret
  191. class TensorSetItemByMixedTensors_1(Cell):
  192. def __init__(self, value):
  193. super(TensorSetItemByMixedTensors_1, self).__init__()
  194. self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float32))
  195. self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
  196. name="x")
  197. self.value = value
  198. def construct(self, index_0, index_1, index_2):
  199. self.param[0:2, index_0, index_1, ..., index_2, 2] = self.value
  200. ret = self.param + self.const
  201. return ret
  202. class TensorSetItemByMixedTensors_2(Cell):
  203. def __init__(self, value):
  204. super(TensorSetItemByMixedTensors_2, self).__init__()
  205. self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float16))
  206. self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float16),
  207. name="x")
  208. self.value = value
  209. def construct(self, index_0, index_1, index_2):
  210. self.param[..., index_0, index_1, index_2, 3] = self.value
  211. ret = self.param + self.const
  212. return ret
  213. class TensorGetItemByMixedTensorsTypeError(Cell):
  214. def __init__(self):
  215. super(TensorGetItemByMixedTensorsTypeError, self).__init__()
  216. def construct(self, x, index_0, index_1):
  217. ret = x[index_0, index_1, 0:3, ..., 0:5, [1, 2, 3, 4]]
  218. return ret
  219. class TensorGetItemByMixedTensorsNumberError(Cell):
  220. def __init__(self):
  221. super(TensorGetItemByMixedTensorsNumberError, self).__init__()
  222. def construct(self, x, index_0, index_1):
  223. ret = x[index_0, index_1, 0:3, ..., index_1, index_0]
  224. return ret
  225. class TensorSetItemByOneTensorWithNumber(Cell):
  226. def __init__(self, value):
  227. super(TensorSetItemByOneTensorWithNumber, self).__init__()
  228. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  229. self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
  230. self.value = value
  231. def construct(self, index):
  232. self.param[index] = self.value
  233. ret = self.param + self.const
  234. return ret
  235. class TensorSetItemByOneTensorWithTensor(Cell):
  236. def __init__(self):
  237. super(TensorSetItemByOneTensorWithTensor, self).__init__()
  238. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  239. self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
  240. def construct(self, index, value):
  241. self.param[index] = value
  242. ret = self.param + self.const
  243. return ret
  244. class TensorSetItemByOneTensorWithTupleOfNumber(Cell):
  245. def __init__(self, value):
  246. super(TensorSetItemByOneTensorWithTupleOfNumber, self).__init__()
  247. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  248. self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
  249. self.value = value
  250. def construct(self, index):
  251. self.param[index] = self.value
  252. ret = self.param + self.const
  253. return ret
  254. class TensorSetItemByOneTensorWithTupleOfTensor(Cell):
  255. def __init__(self):
  256. super(TensorSetItemByOneTensorWithTupleOfTensor, self).__init__()
  257. self.const = Tensor(np.ones((6, 3, 8)), mstype.float32)
  258. self.param = Parameter(Tensor(np.arange(6 * 3 * 8).reshape((6, 3, 8)), mstype.float32), name="x")
  259. def construct(self, index, value_0, value_1, value_2):
  260. self.param[index] = (value_0, value_1, value_2)
  261. ret = self.param + self.const
  262. return ret
  263. class TensorSetItemByTensorsWithNumber(Cell):
  264. def __init__(self, value):
  265. super(TensorSetItemByTensorsWithNumber, self).__init__()
  266. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  267. self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
  268. self.value = value
  269. def construct(self, index_0, index_1, index_2):
  270. self.param[index_0, index_1, index_2] = self.value
  271. ret = self.param + self.const
  272. return ret
  273. class TensorSetItemByTensorsWithTensor(Cell):
  274. def __init__(self):
  275. super(TensorSetItemByTensorsWithTensor, self).__init__()
  276. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  277. self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
  278. def construct(self, index_0, index_1, index_2, value):
  279. self.param[index_0, index_1, index_2] = value
  280. ret = self.param + self.const
  281. return ret
  282. class TensorSetItemByTensorsWithTensorNumberError(Cell):
  283. def __init__(self):
  284. super(TensorSetItemByTensorsWithTensorNumberError, self).__init__()
  285. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  286. self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
  287. def construct(self, index_0, index_1, index_2, index_3, value):
  288. self.param[index_0, index_1, index_2, index_3] = value
  289. ret = self.param + self.const
  290. return ret
  291. class TensorSetItemByTensorsWithTupleOfNumber(Cell):
  292. def __init__(self, value):
  293. super(TensorSetItemByTensorsWithTupleOfNumber, self).__init__()
  294. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  295. self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
  296. self.value = value
  297. def construct(self, index_0, index_1, index_2):
  298. self.param[index_0, index_1, index_2] = self.value
  299. ret = self.param + self.const
  300. return ret
  301. class TensorSetItemByTensorsWithTupleOfTensor(Cell):
  302. def __init__(self):
  303. super(TensorSetItemByTensorsWithTupleOfTensor, self).__init__()
  304. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  305. self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
  306. def construct(self, index_0, index_1, index_2, value_0, value_1, value_2):
  307. self.param[index_0, index_1, index_2] = (value_0, value_1, value_2)
  308. ret = self.param + self.const
  309. return ret
  310. class TensorSetItemByTensorsWithTupleOfTensorNumberError(Cell):
  311. def __init__(self):
  312. super(TensorSetItemByTensorsWithTupleOfTensorNumberError, self).__init__()
  313. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  314. self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
  315. def construct(self, index_0, index_1, index_2, value_0, value_1):
  316. self.param[index_0, index_1, index_2] = (value_0, value_1)
  317. ret = self.param + self.const
  318. return ret
  319. class TensorSetItemByMixedTensors(Cell):
  320. def __init__(self):
  321. super(TensorSetItemByMixedTensors, self).__init__()
  322. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  323. self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
  324. self.value = 99.0
  325. def construct(self, index_0, index_1):
  326. self.param[index_0, index_1, 0:6] = self.value
  327. ret = self.param + self.const
  328. return ret
  329. class TensorAssignWithSliceError1(Cell):
  330. def __init__(self):
  331. super(TensorAssignWithSliceError1, self).__init__()
  332. def construct(self, a, b):
  333. a[1:3:-1, ::] = b
  334. return a
  335. class TensorAssignWithSliceError2(Cell):
  336. def __init__(self):
  337. super(TensorAssignWithSliceError2, self).__init__()
  338. def construct(self, a, b):
  339. a[1:3:-1] = b
  340. return a
  341. class TensorAssignWithSlice2(Cell):
  342. def __init__(self):
  343. super(TensorAssignWithSlice2, self).__init__()
  344. def construct(self, a, b, ck):
  345. a[1:5] = b
  346. a[3:4] = 5
  347. a[-1:1:-1] = b
  348. a[-1:3:-1] = 5
  349. a[::] = b
  350. a[::] = 9
  351. z = a + ck
  352. return z
  353. class TensorAssignWithSlice(Cell):
  354. def __init__(self):
  355. super(TensorAssignWithSlice, self).__init__()
  356. self.c = 2
  357. def construct(self, a, b, ck):
  358. a[1:3, ::] = b
  359. a[2:3:, 3:] = b
  360. a[::] = b
  361. a[::] = self.c
  362. a[::, ::] = b
  363. a[::, ::] = self.c
  364. a[2:3:, 0:, 4:1:-1] = b
  365. a[2:3:, 0:, 4:1:-1] = self.c
  366. z = a + ck
  367. return z
  368. def test_tensor_assign():
  369. context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
  370. net = TensorAssignWithSlice()
  371. net2 = TensorAssignWithSlice2()
  372. net_e1 = TensorAssignWithSliceError1()
  373. net_e2 = TensorAssignWithSliceError2()
  374. a = np.arange(60).reshape(3, 4, 5)
  375. ck = np.arange(60).reshape(3, 4, 5)
  376. b = Tensor([1], dtype=mstype.float32)
  377. Ta = Tensor(a, dtype=mstype.float32)
  378. Tck = Tensor(ck, dtype=mstype.float32)
  379. Ta4d = Tensor(a.reshape(1, 3, 4, 5), dtype=mstype.float32)
  380. Ta4d_ck = Tensor(ck.reshape(1, 3, 4, 5), dtype=mstype.float32)
  381. Tb = Tensor([1, 3], dtype=mstype.float32)
  382. Tc = Tensor([], dtype=mstype.float32)
  383. t = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
  384. tck = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
  385. net(Ta, b, Tck)
  386. net2(t, b, tck)
  387. # Error for A[Slice] = Number
  388. # 1. A[Slice] = Number, Slice error
  389. with pytest.raises(IndexError):
  390. net_e2(t, 2)
  391. # Error for A[Slice] = U, U is a Tensor
  392. # 1. A[Slice] = U, u.size is error
  393. with pytest.raises(ValueError):
  394. net2(t, Tb, tck)
  395. # 2. A[Slice] = U, U is empty
  396. with pytest.raises(ValueError):
  397. net2(t, Tc, tck)
  398. # 3. A[Slice] = U, U.size error
  399. with pytest.raises(ValueError):
  400. net2(t, Tb, tck)
  401. # Error for A[Tuple(Slice...)] = Tensor
  402. # 1. A[Tuple(Slice...)] = U, U is empty
  403. with pytest.raises(ValueError):
  404. net(Ta, Tc, Tck)
  405. # 2. A[Tuple(Slice...)] = U, U.size error
  406. with pytest.raises(ValueError):
  407. net(Ta, Tb, Tck)
  408. # 3. A[Tuple(Slice...)] = U, Slice error
  409. with pytest.raises(IndexError):
  410. net_e1(Ta, b)
  411. # Error for A[Tuple(Slice...)] = Number
  412. # 1. A[Tuple(Slice...)] = Number, Slice error
  413. with pytest.raises(IndexError):
  414. net_e1(Ta, 2)
  415. net = TensorAssignWithInteger()
  416. # Error for A[Number] = scalar/Tensor
  417. # 1. A[Number] = U, U is a Tensor, u.size not match
  418. with pytest.raises(ValueError):
  419. net(Ta, Tb, Tck)
  420. with pytest.raises(ValueError):
  421. net(Ta, Tc, Tck)
  422. # 2. A[Number] = U, the number index error
  423. with pytest.raises(IndexError):
  424. net(Ta4d, b, Ta4d_ck)
  425. # Error for A[(n,m)] = scalar/Tensor
  426. # 1. A[(n,m)] = U, U is a tensor. u.size not match
  427. net = TensorAssignWithTupleInteger()
  428. with pytest.raises(ValueError):
  429. net(Ta, Tc, Tck)
  430. with pytest.raises(ValueError):
  431. net(Ta, Tb, Tck)
  432. # 2. A[(n,m)] = U, the number index error
  433. with pytest.raises(IndexError):
  434. net(Ta4d, b, Ta4d_ck)
  435. # Error for A[...] = U or A[1:, ...] = u
  436. # 1. A[...] = scalar/tensor
  437. net = TensorAssignWithEllipsis()
  438. net(Ta, Ta4d)
  439. with pytest.raises(ValueError):
  440. net(Ta, Tc)
  441. with pytest.raises(ValueError):
  442. net(Ta, Tb)
  443. # 2. A[::, 1:, ...] = scalar/tensor
  444. net = TensorAssignWithTupleEllipsis()
  445. net(Ta, b)
  446. Tc = Tensor(1, mstype.float32)
  447. with pytest.raises(ValueError):
  448. net(Ta, Tc)
  449. with pytest.raises(ValueError):
  450. net(Ta, Tb)
  451. class TensorAssignWithTupleEllipsis2(Cell):
  452. def __init__(self):
  453. super(TensorAssignWithTupleEllipsis2, self).__init__()
  454. def construct(self, a, b):
  455. a[1:, ..., ::] = b
  456. return a
  457. class TensorAssignWithTupleEllipsis(Cell):
  458. def __init__(self):
  459. super(TensorAssignWithTupleEllipsis, self).__init__()
  460. def construct(self, a, b):
  461. a[:2, ...] = 1
  462. a[1:, ...] = b
  463. return a
  464. class TensorAssignWithEllipsis(Cell):
  465. def __init__(self):
  466. super(TensorAssignWithEllipsis, self).__init__()
  467. def construct(self, a, b):
  468. a[...] = 1
  469. a[...] = b
  470. return a
  471. class TensorAssignWithInteger(Cell):
  472. def __init__(self):
  473. super(TensorAssignWithInteger, self).__init__()
  474. def construct(self, a, b, ck):
  475. a[1] = 1
  476. a[0] = b
  477. z = a + ck
  478. return z
  479. class TensorAssignWithTupleInteger(Cell):
  480. def __init__(self):
  481. super(TensorAssignWithTupleInteger, self).__init__()
  482. def construct(self, a, b, ck):
  483. a[(1)] = 1
  484. a[(1)] = b
  485. a[(1, 1)] = b
  486. a[(1, 1)] = 1
  487. z = a + ck
  488. return z
  489. class TensorAssignWithBoolTensorIndex(Cell):
  490. def __init__(self):
  491. super(TensorAssignWithBoolTensorIndex, self).__init__()
  492. self.t = Tensor(np.arange(60).reshape([3, 4, 5]), dtype=mstype.float32)
  493. self.u_scalar = 5
  494. def construct(self, a, b, c, u_tensor):
  495. a[c] = self.u_scalar
  496. a[b] = u_tensor
  497. z = a + self.t
  498. return z
  499. class TensorAssignWithBoolTensorIndexError(Cell):
  500. def __init__(self):
  501. super(TensorAssignWithBoolTensorIndexError, self).__init__()
  502. def construct(self, a, b, c, u_tensor):
  503. a[b][c] = u_tensor
  504. return a
  505. class TensorAssignWithBoolTensorIndex2(Cell):
  506. def __init__(self):
  507. super(TensorAssignWithBoolTensorIndex2, self).__init__()
  508. self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float32)
  509. self.t = Tensor(np.arange(60).reshape([3, 4, 5]), dtype=mstype.float32)
  510. self.u_scalar = 5
  511. def construct(self, a, u_tensor):
  512. a[a > 8] = u_tensor
  513. a[a >= 6] = self.u_scalar
  514. a[a < 3] = self.u_scalar
  515. a[a <= 5] = u_tensor
  516. a[a == 5] = self.u_scalar
  517. z = a + self.t
  518. return z
  519. class TensorAssignWithBoolTensorIndex2Error(Cell):
  520. def __init__(self):
  521. super(TensorAssignWithBoolTensorIndex2Error, self).__init__()
  522. def construct(self, a, u_tensor):
  523. a[a > 8][a > 5] = u_tensor
  524. return a
  525. def Xtest_tensor_assign_bool_index():
  526. a = np.arange(60).reshape(3, 4, 5)
  527. b = a > 5
  528. c = a < 3
  529. Ta = Tensor(a, dtype=mstype.float32)
  530. Tb = Tensor(b)
  531. Tc = Tensor(c)
  532. Td = Tensor([True, True])
  533. u_tensor = Tensor([1], dtype=mstype.float32)
  534. u_tensor_error = Tensor([1, 2], dtype=mstype.float32)
  535. u_scalar = 5
  536. net1 = TensorAssignWithBoolTensorIndex()
  537. net2 = TensorAssignWithBoolTensorIndex2()
  538. net1(Ta, Tb, Tc, u_tensor)
  539. net1(Ta, Tb, Tc, u_tensor)
  540. with pytest.raises(ValueError):
  541. net1(Ta, Td, Tc, u_tensor)
  542. with pytest.raises(IndexError):
  543. net1(Ta, u_tensor, Tc, u_tensor)
  544. with pytest.raises(ValueError):
  545. net1(Ta, Tb, Td, u_tensor)
  546. with pytest.raises(IndexError):
  547. net1(Ta, Tb, Ta, u_tensor)
  548. with pytest.raises(ValueError):
  549. net1(Ta, Tb, Tc, u_tensor_error)
  550. # net1(Ta, u_tensor, Tc, u_tensor_error, u_scalar)
  551. with pytest.raises(ValueError):
  552. net2(Ta, u_tensor_error)
  553. net3 = TensorAssignWithBoolTensorIndexError()
  554. with pytest.raises(AttributeError):
  555. net3(Ta, Tb, Tc, u_tensor)
  556. with pytest.raises(AttributeError):
  557. net3(Ta, Tb, Tc, u_scalar)
  558. net4 = TensorAssignWithBoolTensorIndex2Error()
  559. with pytest.raises(AttributeError):
  560. net4(Ta, u_tensor)
  561. with pytest.raises(AttributeError):
  562. net4(Ta, u_scalar)
  563. def Xtest_tensor_slice_reduce_out_of_bounds_neg():
  564. class NetWork(Cell):
  565. def __init__(self):
  566. super(NetWork, self).__init__()
  567. self.tensor_ret = Tensor(np.array(9, np.int32))
  568. def construct(self, tensor):
  569. ret = tensor[-7, 3, 4]
  570. return ret
  571. input_tensor = Tensor(np.ones([6, 8, 10], np.int32))
  572. net = NetWork()
  573. with pytest.raises(ValueError) as ex:
  574. net(input_tensor)
  575. assert "For 'StridedSlice' the `begin[0]` should be an int and must greater or equal to -6, but got `-7`" in str(
  576. ex.value)
  577. def Xtest_tensor_slice_reduce_out_of_bounds_positive():
  578. class NetWork(Cell):
  579. def __init__(self):
  580. super(NetWork, self).__init__()
  581. self.tensor_ret = Tensor(np.array(9, np.int32))
  582. def construct(self, tensor):
  583. ret = tensor[6, 3, 4]
  584. return ret
  585. input_tensor = Tensor(np.ones([6, 8, 10], np.int32))
  586. net = NetWork()
  587. with pytest.raises(ValueError) as ex:
  588. net(input_tensor)
  589. assert "For 'StridedSlice' the `begin[0]` should be an int and must less than 6, but got `6`" in str(ex.value)