You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_tensor_slice.py 34 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test_tensor_slice """
  16. import numpy as np
  17. import pytest
  18. from mindspore import Tensor, Parameter
  19. from mindspore import context
  20. from mindspore import dtype as mstype
  21. from mindspore.nn import Cell
  22. from ....mindspore_test_framework.mindspore_test import mindspore_test
  23. from ....mindspore_test_framework.pipeline.forward.compile_forward \
  24. import pipeline_for_compile_forward_ge_graph_for_case_by_case_config, \
  25. pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception
  26. class NetWorkSlicePositive(Cell):
  27. def __init__(self):
  28. super(NetWorkSlicePositive, self).__init__()
  29. self.tensor_ret0 = Tensor(np.ones([1, 2, 2], np.int32))
  30. self.tensor_ret1 = Tensor(np.ones([4, 7, 4], np.int32))
  31. self.tensor_ret2 = Tensor(np.ones([6, 8, 10], np.int32))
  32. self.tensor_ret3 = Tensor(np.ones([3, 8, 10], np.int32))
  33. def construct(self, tensor):
  34. ret0 = tensor[3:4:3, 1:5:2, 3:6:2] + self.tensor_ret0
  35. ret1 = tensor[-6:4:1, 7:-8:-1, ::3] + self.tensor_ret1
  36. ret2 = tensor[::, ::, ::] + self.tensor_ret2
  37. ret3 = tensor[::2] + self.tensor_ret3
  38. return ret0, ret1, ret2, ret3
  39. class NetWorkSliceEllipsis(Cell):
  40. def __init__(self):
  41. super(NetWorkSliceEllipsis, self).__init__()
  42. self.tensor_ret0 = Tensor(np.ones([2, 7, 8], np.int32))
  43. self.tensor_ret1 = Tensor(np.ones([6, 7, 8, 9], np.int32))
  44. self.tensor_ret2 = Tensor(np.ones([1, 6, 7, 8, 9], np.int32))
  45. def construct(self, tensor):
  46. ret0 = tensor[0:4:2, ..., 1] + self.tensor_ret0
  47. ret1 = tensor[...] + self.tensor_ret1
  48. ret2 = tensor[None] + self.tensor_ret2
  49. ret3 = tensor[True] + self.tensor_ret2
  50. return ret0, ret1, ret2, ret3
  51. class NetWorkReduceDimension(Cell):
  52. def __init__(self):
  53. super(NetWorkReduceDimension, self).__init__()
  54. self.tensor_ret0 = Tensor(np.ones([2, 4, 1], np.int32))
  55. self.tensor_ret1 = Tensor(np.ones([3, 4], np.int32))
  56. self.tensor_ret2 = Tensor(np.ones([6, 8], np.int32))
  57. self.tensor_ret3 = Tensor(np.array(8, np.int32))
  58. self.tensor_ret4 = Tensor(np.ones([8, 10], np.int32))
  59. def construct(self, tensor):
  60. ret0 = tensor[0:6:3, 1:5:1, 3:5:2] + self.tensor_ret0
  61. ret1 = tensor[::2, 1, ::3] + self.tensor_ret1
  62. ret2 = tensor[::, ::, 0] + self.tensor_ret2
  63. ret3 = tensor[3, 2, 5] + self.tensor_ret3
  64. ret4 = tensor[1] + self.tensor_ret4
  65. return ret0, ret1, ret2, ret3, ret4
  66. class NetWorkStepNegative(Cell):
  67. def __init__(self):
  68. super(NetWorkStepNegative, self).__init__()
  69. self.tensor_ret = Tensor(np.ones([6, 5, 10], np.int32))
  70. def construct(self, tensor):
  71. ret = tensor[::1, -5::, ::-1] + self.tensor_ret
  72. return ret
  73. class NetWorkReduceToScalar(Cell):
  74. def __init__(self):
  75. super(NetWorkReduceToScalar, self).__init__()
  76. self.tensor_ret = Tensor(np.array(9, np.int32))
  77. def construct(self, tensor):
  78. ret = tensor[2, 3, 4] + self.tensor_ret
  79. return ret
  80. class TensorAssignWithSliceError1(Cell):
  81. def __init__(self):
  82. super(TensorAssignWithSliceError1, self).__init__()
  83. def construct(self, a, b):
  84. a[1:3:-1, ::] = b
  85. return a
  86. class TensorAssignWithSliceError2(Cell):
  87. def __init__(self):
  88. super(TensorAssignWithSliceError2, self).__init__()
  89. def construct(self, a, b):
  90. a[1:3:-1] = b
  91. return a
  92. class TensorAssignWithSlice2(Cell):
  93. def __init__(self):
  94. super(TensorAssignWithSlice2, self).__init__()
  95. def construct(self, a, b, ck):
  96. a[1:5] = b
  97. a[3:4] = 5
  98. a[-1:1:-1] = b
  99. a[-1:3:-1] = 5
  100. a[::] = b
  101. a[::] = 9
  102. z = a + ck
  103. return z
  104. class TensorAssignWithSlice(Cell):
  105. def __init__(self):
  106. super(TensorAssignWithSlice, self).__init__()
  107. self.c = 2
  108. def construct(self, a, b, ck):
  109. a[1:3, ::] = b
  110. a[2:3:, 3:] = b
  111. a[::] = b
  112. a[::] = self.c
  113. a[::, ::] = b
  114. a[::, ::] = self.c
  115. a[2:3:, 0:, 4:1:-1] = b
  116. a[2:3:, 0:, 4:1:-1] = self.c
  117. z = a + ck
  118. return z
  119. class TensorGetItemByOneTensor(Cell):
  120. def __init__(self):
  121. super(TensorGetItemByOneTensor, self).__init__()
  122. self.const = Tensor(np.ones((5, 4, 7, 8)), mstype.int32)
  123. def construct(self, x, index):
  124. ret = x[index] + self.const
  125. return ret
  126. class TensorGetItemByTwoTensors(Cell):
  127. def __init__(self):
  128. super(TensorGetItemByTwoTensors, self).__init__()
  129. self.const = Tensor(np.ones((3, 4, 5, 8)), mstype.int32)
  130. def construct(self, x, index_0, index_1):
  131. ret = x[index_0, index_1] + self.const
  132. return ret
  133. class TensorGetItemByThreeTensors(Cell):
  134. def __init__(self):
  135. super(TensorGetItemByThreeTensors, self).__init__()
  136. self.const = Tensor(np.ones((5, 3, 4, 5)), mstype.int32)
  137. def construct(self, x, index_0, index_1, index_2):
  138. ret = x[index_0, index_1, index_2] + self.const
  139. return ret
  140. class TensorGetItemByMixedTensors(Cell):
  141. def __init__(self):
  142. super(TensorGetItemByMixedTensors, self).__init__()
  143. def construct(self, x, index_0, index_1):
  144. ret = x[index_0, index_1, 0:6]
  145. return ret
  146. class TensorSetItemByOneTensorWithNumber(Cell):
  147. def __init__(self, value):
  148. super(TensorSetItemByOneTensorWithNumber, self).__init__()
  149. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  150. self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
  151. self.value = value
  152. def construct(self, index):
  153. self.param[index] = self.value
  154. ret = self.param + self.const
  155. return ret
  156. class TensorSetItemByOneTensorWithTensor(Cell):
  157. def __init__(self):
  158. super(TensorSetItemByOneTensorWithTensor, self).__init__()
  159. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  160. self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
  161. def construct(self, index, value):
  162. self.param[index] = value
  163. ret = self.param + self.const
  164. return ret
  165. class TensorSetItemByOneTensorWithTupleOfNumber(Cell):
  166. def __init__(self, value):
  167. super(TensorSetItemByOneTensorWithTupleOfNumber, self).__init__()
  168. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  169. self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
  170. self.value = value
  171. def construct(self, index):
  172. self.param[index] = self.value
  173. ret = self.param + self.const
  174. return ret
  175. class TensorSetItemByOneTensorWithTupleOfTensor(Cell):
  176. def __init__(self):
  177. super(TensorSetItemByOneTensorWithTupleOfTensor, self).__init__()
  178. self.const = Tensor(np.ones((6, 3, 8)), mstype.float32)
  179. self.param = Parameter(Tensor(np.arange(6*3*8).reshape((6, 3, 8)), mstype.float32), name="x")
  180. def construct(self, index, value_0, value_1, value_2):
  181. self.param[index] = (value_0, value_1, value_2)
  182. ret = self.param + self.const
  183. return ret
  184. class TensorSetItemByTensorsWithNumber(Cell):
  185. def __init__(self, value):
  186. super(TensorSetItemByTensorsWithNumber, self).__init__()
  187. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  188. self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
  189. self.value = value
  190. def construct(self, index_0, index_1, index_2):
  191. self.param[index_0, index_1, index_2] = self.value
  192. ret = self.param + self.const
  193. return ret
  194. class TensorSetItemByTensorsWithTensor(Cell):
  195. def __init__(self):
  196. super(TensorSetItemByTensorsWithTensor, self).__init__()
  197. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  198. self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
  199. def construct(self, index_0, index_1, index_2, value):
  200. self.param[index_0, index_1, index_2] = value
  201. ret = self.param + self.const
  202. return ret
  203. class TensorSetItemByTensorsWithTensorNumberError(Cell):
  204. def __init__(self):
  205. super(TensorSetItemByTensorsWithTensorNumberError, self).__init__()
  206. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  207. self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
  208. def construct(self, index_0, index_1, index_2, index_3, value):
  209. self.param[index_0, index_1, index_2, index_3] = value
  210. ret = self.param + self.const
  211. return ret
  212. class TensorSetItemByTensorsWithTupleOfNumber(Cell):
  213. def __init__(self, value):
  214. super(TensorSetItemByTensorsWithTupleOfNumber, self).__init__()
  215. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  216. self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
  217. self.value = value
  218. def construct(self, index_0, index_1, index_2):
  219. self.param[index_0, index_1, index_2] = self.value
  220. ret = self.param + self.const
  221. return ret
  222. class TensorSetItemByTensorsWithTupleOfTensor(Cell):
  223. def __init__(self):
  224. super(TensorSetItemByTensorsWithTupleOfTensor, self).__init__()
  225. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  226. self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
  227. def construct(self, index_0, index_1, index_2, value_0, value_1, value_2):
  228. self.param[index_0, index_1, index_2] = (value_0, value_1, value_2)
  229. ret = self.param + self.const
  230. return ret
  231. class TensorSetItemByTensorsWithTupleOfTensorNumberError(Cell):
  232. def __init__(self):
  233. super(TensorSetItemByTensorsWithTupleOfTensorNumberError, self).__init__()
  234. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  235. self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
  236. def construct(self, index_0, index_1, index_2, value_0, value_1):
  237. self.param[index_0, index_1, index_2] = (value_0, value_1)
  238. ret = self.param + self.const
  239. return ret
  240. class TensorSetItemByMixedTensors(Cell):
  241. def __init__(self):
  242. super(TensorSetItemByMixedTensors, self).__init__()
  243. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  244. self.param = Parameter(Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.float32), name="x")
  245. self.value = 99.0
  246. def construct(self, index_0, index_1):
  247. self.param[index_0, index_1, 0:6] = self.value
  248. ret = self.param + self.const
  249. return ret
  250. def test_tensor_assign():
  251. context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
  252. net = TensorAssignWithSlice()
  253. net2 = TensorAssignWithSlice2()
  254. net_e1 = TensorAssignWithSliceError1()
  255. net_e2 = TensorAssignWithSliceError2()
  256. a = np.arange(60).reshape(3, 4, 5)
  257. ck = np.arange(60).reshape(3, 4, 5)
  258. b = Tensor([1], dtype=mstype.float32)
  259. Ta = Tensor(a, dtype=mstype.float32)
  260. Tck = Tensor(ck, dtype=mstype.float32)
  261. Ta4d = Tensor(a.reshape(1, 3, 4, 5), dtype=mstype.float32)
  262. Ta4d_ck = Tensor(ck.reshape(1, 3, 4, 5), dtype=mstype.float32)
  263. Tb = Tensor([1, 3], dtype=mstype.float32)
  264. Tc = Tensor([], dtype=mstype.float32)
  265. t = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
  266. tck = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
  267. net(Ta, b, Tck)
  268. net2(t, b, tck)
  269. # Error for A[Slice] = Number
  270. # 1. A[Slice] = Number, Slice error
  271. with pytest.raises(IndexError):
  272. net_e2(t, 2)
  273. # Error for A[Slice] = U, U is a Tensor
  274. # 1. A[Slice] = U, u.size is error
  275. with pytest.raises(ValueError):
  276. net2(t, Tb, tck)
  277. # 2. A[Slice] = U, U is empty
  278. with pytest.raises(ValueError):
  279. net2(t, Tc, tck)
  280. # 3. A[Slice] = U, U.size error
  281. with pytest.raises(ValueError):
  282. net2(t, Tb, tck)
  283. # Error for A[Tuple(Slice...)] = Tensor
  284. # 1. A[Tuple(Slice...)] = U, U is empty
  285. with pytest.raises(ValueError):
  286. net(Ta, Tc, Tck)
  287. # 2. A[Tuple(Slice...)] = U, U.size error
  288. with pytest.raises(ValueError):
  289. net(Ta, Tb, Tck)
  290. # 3. A[Tuple(Slice...)] = U, Slice error
  291. with pytest.raises(IndexError):
  292. net_e1(Ta, b)
  293. # Error for A[Tuple(Slice...)] = Number
  294. # 1. A[Tuple(Slice...)] = Number, Slice error
  295. with pytest.raises(IndexError):
  296. net_e1(Ta, 2)
  297. net = TensorAssignWithInteger()
  298. # Error for A[Number] = scalar/Tensor
  299. # 1. A[Number] = U, U is a Tensor, u.size not match
  300. with pytest.raises(ValueError):
  301. net(Ta, Tb, Tck)
  302. with pytest.raises(ValueError):
  303. net(Ta, Tc, Tck)
  304. # 2. A[Number] = U, the number index error
  305. with pytest.raises(IndexError):
  306. net(Ta4d, b, Ta4d_ck)
  307. # Error for A[(n,m)] = scalar/Tensor
  308. # 1. A[(n,m)] = U, U is a tensor. u.size not match
  309. net = TensorAssignWithTupleInteger()
  310. with pytest.raises(ValueError):
  311. net(Ta, Tc, Tck)
  312. with pytest.raises(ValueError):
  313. net(Ta, Tb, Tck)
  314. # 2. A[(n,m)] = U, the number index error
  315. with pytest.raises(IndexError):
  316. net(Ta4d, b, Ta4d_ck)
  317. # Error for A[...] = U or A[1:, ...] = u
  318. # 1. A[...] = scalar/tensor
  319. net = TensorAssignWithEllipsis()
  320. net(Ta, Ta4d)
  321. with pytest.raises(ValueError):
  322. net(Ta, Tc)
  323. with pytest.raises(ValueError):
  324. net(Ta, Tb)
  325. # 2. A[::, 1:, ...] = scalar/tensor
  326. net = TensorAssignWithTupleEllipsis()
  327. net(Ta, b)
  328. with pytest.raises(ValueError):
  329. net(Ta, Tc)
  330. with pytest.raises(ValueError):
  331. net(Ta, Tb)
  332. class TensorAssignWithTupleEllipsis2(Cell):
  333. def __init__(self):
  334. super(TensorAssignWithTupleEllipsis2, self).__init__()
  335. def construct(self, a, b):
  336. a[1:, ..., ::] = b
  337. return a
  338. class TensorAssignWithTupleEllipsis(Cell):
  339. def __init__(self):
  340. super(TensorAssignWithTupleEllipsis, self).__init__()
  341. def construct(self, a, b):
  342. a[:2, ...] = 1
  343. a[1:, ...] = b
  344. return a
  345. class TensorAssignWithEllipsis(Cell):
  346. def __init__(self):
  347. super(TensorAssignWithEllipsis, self).__init__()
  348. def construct(self, a, b):
  349. a[...] = 1
  350. a[...] = b
  351. return a
  352. class TensorAssignWithInteger(Cell):
  353. def __init__(self):
  354. super(TensorAssignWithInteger, self).__init__()
  355. def construct(self, a, b, ck):
  356. a[1] = 1
  357. a[0] = b
  358. z = a + ck
  359. return z
  360. class TensorAssignWithTupleInteger(Cell):
  361. def __init__(self):
  362. super(TensorAssignWithTupleInteger, self).__init__()
  363. def construct(self, a, b, ck):
  364. a[(1)] = 1
  365. a[(1)] = b
  366. a[(1, 1)] = b
  367. a[(1, 1)] = 1
  368. z = a + ck
  369. return z
  370. class TensorAssignWithBoolTensorIndex(Cell):
  371. def __init__(self):
  372. super(TensorAssignWithBoolTensorIndex, self).__init__()
  373. self.t = Tensor(np.arange(60).reshape([3, 4, 5]), dtype=mstype.float32)
  374. self.u_scalar = 5
  375. def construct(self, a, b, c, u_tensor):
  376. a[c] = self.u_scalar
  377. a[b] = u_tensor
  378. z = a + self.t
  379. return z
  380. class TensorAssignWithBoolTensorIndexError(Cell):
  381. def __init__(self):
  382. super(TensorAssignWithBoolTensorIndexError, self).__init__()
  383. def construct(self, a, b, c, u_tensor):
  384. a[b][c] = u_tensor
  385. return a
  386. class TensorAssignWithBoolTensorIndex2(Cell):
  387. def __init__(self):
  388. super(TensorAssignWithBoolTensorIndex2, self).__init__()
  389. self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float32)
  390. self.t = Tensor(np.arange(60).reshape([3, 4, 5]), dtype=mstype.float32)
  391. self.u_scalar = 5
  392. def construct(self, a, u_tensor):
  393. a[a > 8] = u_tensor
  394. a[a >= 6] = self.u_scalar
  395. a[a < 3] = self.u_scalar
  396. a[a <= 5] = u_tensor
  397. a[a == 5] = self.u_scalar
  398. z = a + self.t
  399. return z
  400. class TensorAssignWithBoolTensorIndex2Error(Cell):
  401. def __init__(self):
  402. super(TensorAssignWithBoolTensorIndex2Error, self).__init__()
  403. def construct(self, a, u_tensor):
  404. a[a > 8][a > 5] = u_tensor
  405. return a
  406. a = np.arange(60).reshape(3, 4, 5)
  407. ck = np.arange(60).reshape(3, 4, 5)
  408. a4 = np.arange(60).reshape(3, 2, 2, 5)
  409. b = a > 5
  410. c = a < 3
  411. Ta = Tensor(a, dtype=mstype.float32)
  412. Tck = Tensor(ck, dtype=mstype.float32)
  413. Ta4 = Tensor(a4, dtype=mstype.float32)
  414. Tb = Tensor(b)
  415. Tc = Tensor(c)
  416. Td = Tensor([True, True])
  417. u_tensor = Tensor([1], dtype=mstype.float32)
  418. u_tensor_error = Tensor([1, 2], dtype=mstype.float32)
  419. t_1d = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
  420. tck_1d = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
  421. u_scalar = 5
  422. def test_tensor_assign_bool_index():
  423. net1 = TensorAssignWithBoolTensorIndex()
  424. net2 = TensorAssignWithBoolTensorIndex2()
  425. net1(Ta, Tb, Tc, u_tensor)
  426. net1(Ta, Tb, Tc, u_tensor)
  427. with pytest.raises(ValueError):
  428. net1(Ta, Td, Tc, u_tensor)
  429. with pytest.raises(TypeError):
  430. net1(Ta, u_tensor, Tc, u_tensor)
  431. with pytest.raises(ValueError):
  432. net1(Ta, Tb, Td, u_tensor)
  433. with pytest.raises(TypeError):
  434. net1(Ta, Tb, Ta, u_tensor)
  435. with pytest.raises(ValueError):
  436. net1(Ta, Tb, Tc, u_tensor_error)
  437. # net1(Ta, u_tensor, Tc, u_tensor_error, u_scalar)
  438. with pytest.raises(ValueError):
  439. net2(Ta, u_tensor_error)
  440. net3 = TensorAssignWithBoolTensorIndexError()
  441. with pytest.raises(AttributeError):
  442. net3(Ta, Tb, Tc, u_tensor)
  443. with pytest.raises(AttributeError):
  444. net3(Ta, Tb, Tc, u_scalar)
  445. net4 = TensorAssignWithBoolTensorIndex2Error()
  446. with pytest.raises(AttributeError):
  447. net4(Ta, u_tensor)
  448. with pytest.raises(AttributeError):
  449. net4(Ta, u_scalar)
  450. test_cases = [
  451. ('TensorAssignWithTupleEllipsis2', {
  452. 'block': TensorAssignWithTupleEllipsis2(),
  453. 'desc_inputs': [Ta4, u_tensor],
  454. }),
  455. ('TensorAssignWithTupleEllipsis', {
  456. 'block': TensorAssignWithTupleEllipsis(),
  457. 'desc_inputs': [Ta, u_tensor],
  458. }),
  459. ('TensorAssignWithEllipsis', {
  460. 'block': TensorAssignWithEllipsis(),
  461. 'desc_inputs': [Ta, u_tensor],
  462. }),
  463. ('TensorAssignWithTupleInteger', {
  464. 'block': TensorAssignWithTupleInteger(),
  465. 'desc_inputs': [Ta, u_tensor, Tck],
  466. }),
  467. ('TensorAssignWithInteger', {
  468. 'block': TensorAssignWithInteger(),
  469. 'desc_inputs': [Ta, u_tensor, Tck],
  470. }),
  471. ('TensorAssignWithSlice', {
  472. 'block': TensorAssignWithSlice(),
  473. 'desc_inputs': [Ta, u_tensor, Tck],
  474. }),
  475. ('TensorAssignWithSlice2', {
  476. 'block': TensorAssignWithSlice2(),
  477. 'desc_inputs': [t_1d, u_tensor, tck_1d],
  478. }),
  479. ('TensorAssignWithBoolTensorIndex', {
  480. 'block': TensorAssignWithBoolTensorIndex(),
  481. 'desc_inputs': [Ta, Tb, Tc, u_tensor],
  482. }),
  483. ('TensorAssignWithBoolTensorIndex2', {
  484. 'block': TensorAssignWithBoolTensorIndex2(),
  485. 'desc_inputs': [Ta, u_tensor],
  486. }),
  487. ('SlicePositive', {
  488. 'block': NetWorkSlicePositive(),
  489. 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))],
  490. }),
  491. ('SliceReduceDimension', {
  492. 'block': NetWorkReduceDimension(),
  493. 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))],
  494. }),
  495. ('SliceNegative', {
  496. 'block': NetWorkStepNegative(),
  497. 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))],
  498. }),
  499. ('SliceReduceToScalar', {
  500. 'block': NetWorkReduceToScalar(),
  501. 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))],
  502. }),
  503. ('TensorSliceEllipsis', {
  504. 'block': NetWorkSliceEllipsis(),
  505. 'desc_inputs': [Tensor(np.ones([6, 7, 8, 9], np.int32))],
  506. }),
  507. ('TensorGetItemByOneTensor', {
  508. 'block': TensorGetItemByOneTensor(),
  509. 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
  510. Tensor(np.random.randint(6, size=(5, 4)), mstype.int32)],
  511. }),
  512. ('TensorGetItemByTwoTensors', {
  513. 'block': TensorGetItemByTwoTensors(),
  514. 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
  515. Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  516. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)],
  517. }),
  518. ('TensorGetItemByThreeTensors', {
  519. 'block': TensorGetItemByThreeTensors(),
  520. 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
  521. Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  522. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  523. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
  524. }),
  525. ('TensorSetItemByOneTensorWithNumber', {
  526. 'block': TensorSetItemByOneTensorWithNumber(value=0.0),
  527. 'desc_inputs': [Tensor(np.random.randint(4, size=(5, 4)), mstype.int32)],
  528. }),
  529. ('TensorSetItemByOneTensorWithTensor', {
  530. 'block': TensorSetItemByOneTensorWithTensor(),
  531. 'desc_inputs': [Tensor(np.random.randint(3, size=(5, 4)), mstype.int32),
  532. Tensor(np.zeros((4, 7, 8)), mstype.float32)],
  533. }),
  534. ('TensorSetItemByOneTensorWithTupleOfNumber', {
  535. 'block': TensorSetItemByOneTensorWithTupleOfNumber(value=(0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7)),
  536. 'desc_inputs': [Tensor(np.random.randint(5, size=(5, 4)), mstype.int32)],
  537. }),
  538. ('TensorSetItemByOneTensorWithTupleOfTensor', {
  539. 'block': TensorSetItemByOneTensorWithTupleOfTensor(),
  540. 'desc_inputs': [Tensor(np.random.randint(6, size=(5, 4)), mstype.int32),
  541. Tensor(np.zeros((8,), np.float32)),
  542. Tensor(np.ones((8,), np.float32)),
  543. Tensor(np.ones((8,), np.float32) * 2)],
  544. }),
  545. ('TensorSetItemByTensorsWithNumber', {
  546. 'block': TensorSetItemByTensorsWithNumber(value=0.0),
  547. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  548. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  549. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
  550. }),
  551. ('TensorSetItemByTensorsWithTensor', {
  552. 'block': TensorSetItemByTensorsWithTensor(),
  553. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  554. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  555. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
  556. Tensor(np.zeros((4, 5)), mstype.float32)],
  557. }),
  558. ('TensorSetItemByTensorsWithTupleOfNumber', {
  559. 'block': TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1.1, 2.2, 3.3, 4.4)),
  560. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  561. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  562. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
  563. }),
  564. ('TensorSetItemByTensorsWithTupleOfTensor', {
  565. 'block': TensorSetItemByTensorsWithTupleOfTensor(),
  566. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  567. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  568. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
  569. Tensor(np.zeros((4, 5)), mstype.float32),
  570. Tensor(np.ones((4, 5)), mstype.float32),
  571. Tensor(np.ones((4, 5)) * 2, mstype.float32)],
  572. })
  573. ]
  574. raise_error_set = [
  575. ('TensorGetItemByOneTensorDtypeError', {
  576. 'block': (TensorGetItemByOneTensor(), {'exception': TypeError}),
  577. 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
  578. Tensor(np.random.randint(6, size=(5, 4)), mstype.int8)],
  579. }),
  580. ('TensorGetItemByTwoTensorsShapeError', {
  581. 'block': (TensorGetItemByTwoTensors(), {'exception': ValueError}),
  582. 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
  583. Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  584. Tensor(np.random.randint(7, size=(2, 3, 5)), mstype.int32)],
  585. }),
  586. ('TensorGetItemByTwoTensorsDtypeError', {
  587. 'block': (TensorGetItemByTwoTensors(), {'exception': TypeError}),
  588. 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
  589. Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  590. Tensor(np.random.randint(7, size=(4, 5)), mstype.float32)],
  591. }),
  592. ('TensorGetItemByThreeTensorsShapeError', {
  593. 'block': (TensorGetItemByThreeTensors(), {'exception': ValueError}),
  594. 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
  595. Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  596. Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32),
  597. Tensor(np.random.randint(8, size=(5, 2, 4, 5)), mstype.int32)],
  598. }),
  599. ('TensorGetItemByThreeTensorsDtypeError', {
  600. 'block': (TensorGetItemByThreeTensors(), {'exception': TypeError}),
  601. 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
  602. Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  603. Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int64),
  604. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
  605. }),
  606. ('TensorGetItemByMixedTensors', {
  607. 'block': (TensorGetItemByMixedTensors(), {'exception': IndexError}),
  608. 'desc_inputs': [Tensor(np.arange(6*7*8).reshape((6, 7, 8)), mstype.int32),
  609. Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  610. Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int64)],
  611. }),
  612. ('TensorSetItemByOneTensorWithNumberTypeError', {
  613. 'block': (TensorSetItemByOneTensorWithNumber(value=0), {'exception': TypeError}),
  614. 'desc_inputs': [Tensor(np.random.randint(4, size=(5, 4)), mstype.int32)],
  615. }),
  616. ('TensorSetItemByOneTensorWithTensorShapeError', {
  617. 'block': (TensorSetItemByOneTensorWithTensor(), {'exception': ValueError}),
  618. 'desc_inputs': [Tensor(np.random.randint(3, size=(5, 4)), mstype.int32),
  619. Tensor(np.zeros((6, 7, 8)), mstype.float32)],
  620. }),
  621. ('TensorSetItemByOneTensorWithTensorDtypeError', {
  622. 'block': (TensorSetItemByOneTensorWithTensor(), {'exception': TypeError}),
  623. 'desc_inputs': [Tensor(np.random.randint(3, size=(5, 4)), mstype.int32),
  624. Tensor(np.zeros((6, 7, 8)), mstype.int32)],
  625. }),
  626. ('TensorSetItemByOneTensorWithTupleOfNumberTypeError', {
  627. 'block': (TensorSetItemByOneTensorWithTupleOfNumber(value=(0, 1, 2, 3, 4, 5, 6, 7)), {'exception': TypeError}),
  628. 'desc_inputs': [Tensor(np.random.randint(5, size=(5, 4)), mstype.int32)],
  629. }),
  630. ('TensorSetItemByOneTensorWithTupleOfNumberNumberError', {
  631. 'block': (TensorSetItemByOneTensorWithTupleOfNumber(value=(0.0, 1.1, 2.2)), {'exception': ValueError}),
  632. 'desc_inputs': [Tensor(np.random.randint(5, size=(5, 4)), mstype.int32)],
  633. }),
  634. ('TensorSetItemByOneTensorWithTupleOfTensorDtyeError', {
  635. 'block': (TensorSetItemByOneTensorWithTupleOfTensor(), {'exception': TypeError}),
  636. 'desc_inputs': [Tensor(np.random.randint(6, size=(5, 4)), mstype.int32),
  637. Tensor(np.zeros((8,), np.int32)),
  638. Tensor(np.ones((8,), np.int32)),
  639. Tensor(np.ones((8,), np.float32) * 2)],
  640. }),
  641. ('TensorSetItemByTensorsWithNumberTypeError', {
  642. 'block': (TensorSetItemByTensorsWithNumber(value=0), {'exception': TypeError}),
  643. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  644. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  645. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
  646. }),
  647. ('TensorSetItemByTensorsWithTensorShapeError', {
  648. 'block': (TensorSetItemByTensorsWithTensor(), {'exception': ValueError}),
  649. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  650. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  651. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
  652. Tensor(np.zeros((2, 5)), mstype.float32)],
  653. }),
  654. ('TensorSetItemByTensorsWithTensorTypeError', {
  655. 'block': (TensorSetItemByTensorsWithTensor(), {'exception': TypeError}),
  656. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  657. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  658. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
  659. Tensor(np.zeros((4, 5)), mstype.int32)],
  660. }),
  661. ('TensorSetItemByTensorsWithTensorNumberError', {
  662. 'block': (TensorSetItemByTensorsWithTensorNumberError(), {'exception': IndexError}),
  663. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  664. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  665. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
  666. Tensor(np.random.randint(8, size=(1, 3, 4, 5)), mstype.int32),
  667. Tensor(np.zeros((2, 5)), mstype.float32)],
  668. }),
  669. ('TensorSetItemByTensorsWithTupleOfNumberTypeError', {
  670. 'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0, 1, 2, 3, 4)), {'exception': TypeError}),
  671. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  672. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  673. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
  674. }),
  675. ('TensorSetItemByTensorsWithTupleOfNumberNumberError', {
  676. 'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1.0, 2.0, 3.0)), {'exception': ValueError}),
  677. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  678. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  679. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
  680. }),
  681. ('TensorSetItemByTensorsWithTupleOfTensorNumberError', {
  682. 'block': (TensorSetItemByTensorsWithTupleOfTensorNumberError(), {'exception': ValueError}),
  683. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  684. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  685. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
  686. Tensor(np.zeros((4, 5)), mstype.float32),
  687. Tensor(np.ones((4, 5)), mstype.float32)],
  688. }),
  689. ('TensorSetItemByTensorsWithTupleOfTensorTypeError', {
  690. 'block': (TensorSetItemByTensorsWithTupleOfTensor(), {'exception': TypeError}),
  691. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  692. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  693. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
  694. Tensor(np.zeros((4, 5)), mstype.float32),
  695. Tensor(np.ones((4, 5)), mstype.int32),
  696. Tensor(np.ones((4, 5)) * 2, mstype.int32)],
  697. }),
  698. ('TensorSetItemByMixedTensors', {
  699. 'block': (TensorSetItemByMixedTensors(), {'exception': IndexError}),
  700. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  701. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)],
  702. })
  703. ]
  704. @mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
  705. def test_exec():
  706. context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
  707. return test_cases
  708. @mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception)
  709. def test_check_exception():
  710. return raise_error_set
  711. def test_tensor_slice_reduce_out_of_bounds_neg():
  712. class NetWork(Cell):
  713. def __init__(self):
  714. super(NetWork, self).__init__()
  715. self.tensor_ret = Tensor(np.array(9, np.int32))
  716. def construct(self, tensor):
  717. ret = tensor[-7, 3, 4]
  718. return ret
  719. input_tensor = Tensor(np.ones([6, 8, 10], np.int32))
  720. net = NetWork()
  721. with pytest.raises(ValueError) as ex:
  722. net(input_tensor)
  723. assert "For 'StridedSlice' the `begin[0]` should be an int and must greater or equal to -6, but got `-7`" in str(
  724. ex.value)
  725. def test_tensor_slice_reduce_out_of_bounds_positive():
  726. class NetWork(Cell):
  727. def __init__(self):
  728. super(NetWork, self).__init__()
  729. self.tensor_ret = Tensor(np.array(9, np.int32))
  730. def construct(self, tensor):
  731. ret = tensor[6, 3, 4]
  732. return ret
  733. input_tensor = Tensor(np.ones([6, 8, 10], np.int32))
  734. net = NetWork()
  735. with pytest.raises(ValueError) as ex:
  736. net(input_tensor)
  737. assert "For 'StridedSlice' the `begin[0]` should be an int and must less than 6, but got `6`" in str(ex.value)