You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_tensor_slice.py 52 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test_tensor_slice """
  16. import numpy as np
  17. import pytest
  18. from mindspore import Tensor, Parameter
  19. from mindspore import context
  20. from mindspore import dtype as mstype
  21. from mindspore.nn import Cell
  22. from ....mindspore_test_framework.mindspore_test import mindspore_test
  23. from ....mindspore_test_framework.pipeline.forward.compile_forward \
  24. import pipeline_for_compile_forward_ge_graph_for_case_by_case_config, \
  25. pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception
class NetWorkSlicePositive(Cell):
    """Basic start:stop:step slicing on every axis of a 3-D tensor.

    The 'SlicePositive' entry in ``test_cases`` feeds an int32 tensor of
    shape (6, 8, 10). Each ``tensor_retN`` constant has the shape that the
    matching slice produces under NumPy slicing rules, and is added to it.
    """
    def __init__(self):
        super(NetWorkSlicePositive, self).__init__()
        self.tensor_ret0 = Tensor(np.ones([1, 2, 2], np.int32))
        self.tensor_ret1 = Tensor(np.ones([4, 7, 4], np.int32))
        self.tensor_ret2 = Tensor(np.ones([6, 8, 10], np.int32))
        self.tensor_ret3 = Tensor(np.ones([3, 8, 10], np.int32))

    def construct(self, tensor):
        # Explicit start:stop:step on all three axes.
        ret0 = tensor[3:4:3, 1:5:2, 3:6:2] + self.tensor_ret0
        # Negative start; negative step with a negative stop; step-only slice.
        ret1 = tensor[-6:4:1, 7:-8:-1, ::3] + self.tensor_ret1
        # Full-range slices on every axis.
        ret2 = tensor[::, ::, ::] + self.tensor_ret2
        # A single (non-tuple) slice, which only affects axis 0.
        ret3 = tensor[::2] + self.tensor_ret3
        return ret0, ret1, ret2, ret3
class NetWorkSliceEllipsis(Cell):
    """Indexing with Ellipsis, None and True.

    The 'TensorSliceEllipsis' entry in ``test_cases`` feeds an int32 tensor
    of shape (6, 7, 8, 9).
    """
    def __init__(self):
        super(NetWorkSliceEllipsis, self).__init__()
        self.tensor_ret0 = Tensor(np.ones([2, 7, 8], np.int32))
        self.tensor_ret1 = Tensor(np.ones([6, 7, 8, 9], np.int32))
        # Shared by ret2 and ret3: both None and True insert a leading axis of size 1.
        self.tensor_ret2 = Tensor(np.ones([1, 6, 7, 8, 9], np.int32))

    def construct(self, tensor):
        # Slice, then Ellipsis spanning the middle axes, then an integer on the last axis.
        ret0 = tensor[0:4:2, ..., 1] + self.tensor_ret0
        # A bare Ellipsis selects everything.
        ret1 = tensor[...] + self.tensor_ret1
        # None adds a new leading axis.
        ret2 = tensor[None] + self.tensor_ret2
        # A scalar True index; compared against the same (1, ...) constant as None.
        ret3 = tensor[True] + self.tensor_ret2
        return ret0, ret1, ret2, ret3
class NetWorkReduceDimension(Cell):
    """Mixing integer indices with slices, so the result loses axes.

    The 'SliceReduceDimension' entry in ``test_cases`` feeds an int32 tensor
    of shape (6, 8, 10).
    """
    def __init__(self):
        super(NetWorkReduceDimension, self).__init__()
        self.tensor_ret0 = Tensor(np.ones([2, 4, 1], np.int32))
        self.tensor_ret1 = Tensor(np.ones([3, 4], np.int32))
        self.tensor_ret2 = Tensor(np.ones([6, 8], np.int32))
        # 0-d constant for the fully indexed scalar result.
        self.tensor_ret3 = Tensor(np.array(8, np.int32))
        self.tensor_ret4 = Tensor(np.ones([8, 10], np.int32))

    def construct(self, tensor):
        # All slices: no axis is dropped.
        ret0 = tensor[0:6:3, 1:5:1, 3:5:2] + self.tensor_ret0
        # An integer on the middle axis drops it.
        ret1 = tensor[::2, 1, ::3] + self.tensor_ret1
        # An integer on the last axis drops it.
        ret2 = tensor[::, ::, 0] + self.tensor_ret2
        # Integers on every axis give a 0-d result.
        ret3 = tensor[3, 2, 5] + self.tensor_ret3
        # A single integer drops axis 0.
        ret4 = tensor[1] + self.tensor_ret4
        return ret0, ret1, ret2, ret3, ret4
class NetWorkStepNegative(Cell):
    """Negative start and negative step slicing.

    The 'SliceNegative' entry in ``test_cases`` feeds an int32 tensor of
    shape (6, 8, 10).
    """
    def __init__(self):
        super(NetWorkStepNegative, self).__init__()
        self.tensor_ret = Tensor(np.ones([6, 5, 10], np.int32))

    def construct(self, tensor):
        # -5:: keeps the last 5 entries of axis 1; ::-1 reverses axis 2.
        ret = tensor[::1, -5::, ::-1] + self.tensor_ret
        return ret
class NetWorkReduceToScalar(Cell):
    """Full integer indexing of a 3-D tensor, which yields a 0-d result."""
    def __init__(self):
        super(NetWorkReduceToScalar, self).__init__()
        # 0-d constant that matches the scalar produced by tensor[2, 3, 4].
        self.tensor_ret = Tensor(np.array(9, np.int32))

    def construct(self, tensor):
        ret = tensor[2, 3, 4] + self.tensor_ret
        return ret
class TensorAssignWithSliceError1(Cell):
    """Tuple-of-slices assignment whose first slice selects nothing.

    ``1:3:-1`` has start < stop with a negative step, so it is empty.
    ``test_tensor_assign`` expects an IndexError for this cell, with both
    tensor and scalar values.
    """
    def __init__(self):
        super(TensorAssignWithSliceError1, self).__init__()

    def construct(self, a, b):
        a[1:3:-1, ::] = b
        return a
class TensorAssignWithSliceError2(Cell):
    """Single-slice assignment whose slice (``1:3:-1``) selects nothing.

    ``test_tensor_assign`` expects a ValueError ("0 in shape") for this cell.
    """
    def __init__(self):
        super(TensorAssignWithSliceError2, self).__init__()

    def construct(self, a, b):
        a[1:3:-1] = b
        return a
class TensorAssignWithSlice2(Cell):
    """Single-slice assignment on a 1-D tensor, with tensor and scalar values.

    In this file it is called with the 8-element 1-D tensors ``t``/``t_1d``
    as ``a`` and an equally shaped ``ck`` that is added to the result.
    """
    def __init__(self):
        super(TensorAssignWithSlice2, self).__init__()

    def construct(self, a, b, ck):
        a[1:5] = b
        a[3:4] = 5
        # Negative start with a negative step.
        a[-1:1:-1] = b
        a[-1:3:-1] = 5
        # Full-range slice.
        a[::] = b
        a[::] = 9
        z = a + ck
        return z
class TensorAssignWithSlice(Cell):
    """Assignment through one, two and three slices on a 3-D tensor.

    In this file ``a``/``ck`` are (3, 4, 5) float32 tensors. Every pattern is
    assigned once with the tensor ``b`` and once with the float ``self.c``.
    """
    def __init__(self):
        super(TensorAssignWithSlice, self).__init__()
        self.c = 2.0

    def construct(self, a, b, ck):
        a[1:3, ::] = b
        a[2:3:, 3:] = b
        a[::] = b
        a[::] = self.c
        a[::, ::] = b
        a[::, ::] = self.c
        # Reversed slice on the last axis.
        a[2:3:, 0:, 4:1:-1] = b
        a[2:3:, 0:, 4:1:-1] = self.c
        z = a + ck
        return z
class TensorGetItemByOneTensor(Cell):
    """Getitem with a single integer index tensor.

    In ``test_cases``: ``x`` is (6, 7, 8) and ``index`` is (5, 4), so
    ``x[index]`` is (5, 4, 7, 8), the shape of ``self.const``.
    """
    def __init__(self):
        super(TensorGetItemByOneTensor, self).__init__()
        self.const = Tensor(np.ones((5, 4, 7, 8)), mstype.int32)

    def construct(self, x, index):
        ret = x[index] + self.const
        return ret
class TensorGetItemByTwoTensors(Cell):
    """Getitem with two index tensors that broadcast together.

    In ``test_cases``: ``x`` is (6, 7, 8); ``index_0`` is (3, 4, 5) and
    ``index_1`` is (4, 5). They broadcast to (3, 4, 5), and the untouched
    last axis (8) follows, which gives the (3, 4, 5, 8) shape of ``self.const``.
    """
    def __init__(self):
        super(TensorGetItemByTwoTensors, self).__init__()
        self.const = Tensor(np.ones((3, 4, 5, 8)), mstype.int32)

    def construct(self, x, index_0, index_1):
        ret = x[index_0, index_1] + self.const
        return ret
class TensorGetItemByThreeTensors(Cell):
    """Getitem with one index tensor per axis.

    In ``test_cases``: the index shapes (3, 4, 5), (4, 5) and (5, 3, 4, 5)
    broadcast to (5, 3, 4, 5), the shape of ``self.const``.
    """
    def __init__(self):
        super(TensorGetItemByThreeTensors, self).__init__()
        self.const = Tensor(np.ones((5, 3, 4, 5)), mstype.int32)

    def construct(self, x, index_0, index_1, index_2):
        ret = x[index_0, index_1, index_2] + self.const
        return ret
class TensorGetItemByMixedTensors_0(Cell):
    """Index tensors mixed with slices, Ellipsis and a trailing integer.

    In ``test_cases``: ``tensor`` is (3, 4, 5, 6, 7, 8); the indices
    broadcast to (3, 4, 5). Under NumPy rules the trailing integer ``3`` is
    an advanced index separated from the tensors, so the broadcast shape
    moves to the front: (3, 4, 5) + (3, 6, 5).
    """
    def __init__(self):
        super(TensorGetItemByMixedTensors_0, self).__init__()
        self.const = Tensor(np.ones((3, 4, 5, 3, 6, 5), np.float32))

    def construct(self, tensor, index_0, index_1):
        ret = tensor[index_0, index_1, 0:3, ..., 0:5, 3] + self.const
        return ret
class TensorGetItemByMixedTensors_1(Cell):
    """Index tensors separated by an Ellipsis, mixed with slices and an integer.

    In ``test_cases``: ``tensor`` is (3, 4, 5, 6, 7, 8). The advanced
    indices are not adjacent, so their broadcast shape (3, 4, 5) leads the
    result, followed by the sliced/skipped axes (3, 5, 5).
    """
    def __init__(self):
        super(TensorGetItemByMixedTensors_1, self).__init__()
        self.const = Tensor(np.ones((3, 4, 5, 3, 5, 5), np.float32))

    def construct(self, tensor, index_0, index_1):
        ret = tensor[0:3, index_0, ..., index_1, 3, 0:5] + self.const
        return ret
class TensorGetItemByMixedTensors_2(Cell):
    """Leading integer plus index tensors, Ellipsis, and a trailing integer.

    In ``test_cases``: ``tensor`` is (3, 4, 5, 6, 7, 8). The trailing integer
    separates the advanced indices, so the broadcast shape (3, 4, 5) leads
    the result, followed by the two Ellipsis axes (6, 7).
    """
    def __init__(self):
        super(TensorGetItemByMixedTensors_2, self).__init__()
        self.const = Tensor(np.ones((3, 4, 5, 6, 7), np.float32))

    def construct(self, tensor, index_0, index_1):
        ret = tensor[0, index_0, index_1, ..., 3] + self.const
        return ret
class TensorGetItemByMixedTensors_3(Cell):
    """Leading Ellipsis with index tensors separated by a slice.

    In ``test_cases``: ``tensor`` is (3, 4, 5, 6, 7, 8). The Ellipsis covers
    axes (3, 4). The non-adjacent index tensors put their broadcast shape
    (3, 4, 5) first, giving (3, 4, 5) + (3, 4) + (3, 5).
    """
    def __init__(self):
        super(TensorGetItemByMixedTensors_3, self).__init__()
        self.const = Tensor(np.ones((3, 4, 5, 3, 4, 3, 5), np.float32))

    def construct(self, tensor, index_0, index_1):
        ret = tensor[..., index_0, 0:3, index_1, 0:5] + self.const
        return ret
class TensorGetItemByMixedTensors_4(Cell):
    """Adjacent advanced indices (three tensors plus an integer) between slices.

    In ``test_cases``: ``tensor`` is (3, 4, 5, 6, 7, 8, 9) and the indices
    broadcast to (2, 3, 4, 5). Because the advanced indices are adjacent,
    their broadcast shape stays in place: (2,) + (2, 3, 4, 5) + (3, 9).
    """
    def __init__(self):
        super(TensorGetItemByMixedTensors_4, self).__init__()
        self.const = Tensor(np.ones((2, 2, 3, 4, 5, 3, 9), np.float32))

    def construct(self, tensor, index_0, index_1, index_2):
        ret = tensor[0:2, index_0, index_1, 2, index_2, 0:3, ...] + self.const
        return ret
class TensorGetItemByMixedTensors_5(Cell):
    """Index tensors split by an Ellipsis, with a leading slice and trailing integer.

    In ``test_cases``: ``tensor`` is (3, 4, 5, 6, 7, 8). The advanced
    indices are not adjacent, so their broadcast shape (2, 3, 4, 5) leads
    the result, followed by (2,) from ``0:2`` and (6,) from the Ellipsis.
    """
    def __init__(self):
        super(TensorGetItemByMixedTensors_5, self).__init__()
        self.const = Tensor(np.ones((2, 3, 4, 5, 2, 6), np.float32))

    def construct(self, tensor, index_0, index_1, index_2):
        ret = tensor[0:2, index_0, index_1, ..., index_2, 2] + self.const
        return ret
class TensorGetItemByMixedTensors_6(Cell):
    """Leading Ellipsis followed by adjacent index tensors and an integer.

    In ``test_cases``: ``tensor`` is (3, 4, 5, 6, 7, 8). The Ellipsis covers
    (3, 4). The adjacent advanced indices broadcast to (2, 3, 4, 5) in place.
    """
    def __init__(self):
        super(TensorGetItemByMixedTensors_6, self).__init__()
        self.const = Tensor(np.ones((3, 4, 2, 3, 4, 5), np.float32))

    def construct(self, tensor, index_0, index_1, index_2):
        ret = tensor[..., index_0, index_1, index_2, 3] + self.const
        return ret
class TensorSetItemByMixedTensors_0(Cell):
    """Setitem on a 7-D float32 parameter with mixed tensor/int/slice/Ellipsis indices.

    ``value`` is fixed at construction time; what kind of value it is
    (number, tensor, tuple) is chosen by whichever test case builds this cell.
    """
    def __init__(self, value):
        super(TensorSetItemByMixedTensors_0, self).__init__()
        self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8, 9), np.float32))
        # Parameter with distinct values 0..N-1, shape (3, 4, 5, 6, 7, 8, 9).
        self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)),
                                      mstype.float32),
                               name="x")
        self.value = value

    def construct(self, index_0, index_1, index_2):
        self.param[0:2, index_0, index_1, 2, index_2, 0:3, ...] = self.value
        ret = self.param + self.const
        return ret
class TensorSetItemByMixedTensors_1(Cell):
    """Setitem on a 6-D float32 parameter; index tensors are split by an Ellipsis.

    ``value`` is fixed at construction time by the test case that builds the cell.
    """
    def __init__(self, value):
        super(TensorSetItemByMixedTensors_1, self).__init__()
        self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float32))
        self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
                               name="x")
        self.value = value

    def construct(self, index_0, index_1, index_2):
        self.param[0:2, index_0, index_1, ..., index_2, 2] = self.value
        ret = self.param + self.const
        return ret
class TensorSetItemByMixedTensors_2(Cell):
    """Setitem with a leading Ellipsis on a float16 parameter (unlike the float32 siblings).

    ``value`` is fixed at construction time by the test case that builds the cell.
    """
    def __init__(self, value):
        super(TensorSetItemByMixedTensors_2, self).__init__()
        self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float16))
        self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float16),
                               name="x")
        self.value = value

    def construct(self, index_0, index_1, index_2):
        self.param[..., index_0, index_1, index_2, 3] = self.value
        ret = self.param + self.const
        return ret
class TensorGetItemByMixedTensorsTypeError(Cell):
    """Error case: a Python list used as an index alongside index tensors.

    Per the class name, this index is presumably expected to be rejected with
    a TypeError. The actual expectation lives in the test cases that use it.
    """
    def __init__(self):
        super(TensorGetItemByMixedTensorsTypeError, self).__init__()

    def construct(self, x, index_0, index_1):
        ret = x[index_0, index_1, 0:3, ..., 0:5, [1, 2, 3, 4]]
        return ret
class TensorGetItemByMixedTensorsNumberError(Cell):
    """Error case: five non-Ellipsis indices, reusing the same index tensors.

    Per the class name, this is presumably expected to fail because of the
    number of indices relative to ``x``'s rank (TODO confirm against the test
    case that uses it).
    """
    def __init__(self):
        super(TensorGetItemByMixedTensorsNumberError, self).__init__()

    def construct(self, x, index_0, index_1):
        ret = x[index_0, index_1, 0:3, ..., index_1, index_0]
        return ret
class TensorSetItemByOneTensorWithNumber(Cell):
    """Setitem through one index tensor with a number value.

    ``test_cases`` builds it with ``value=0.0`` and a (5, 4) int32 index.
    """
    def __init__(self, value):
        super(TensorSetItemByOneTensorWithNumber, self).__init__()
        self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
        self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
        self.value = value

    def construct(self, index):
        self.param[index] = self.value
        ret = self.param + self.const
        return ret
class TensorSetItemByOneTensorWithTensor(Cell):
    """Setitem through one index tensor with a tensor value.

    ``test_cases`` passes a (5, 4) index and a (4, 7, 8) value. The selected
    region is (5, 4, 7, 8), so the value must broadcast into it.
    """
    def __init__(self):
        super(TensorSetItemByOneTensorWithTensor, self).__init__()
        self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
        self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")

    def construct(self, index, value):
        self.param[index] = value
        ret = self.param + self.const
        return ret
class TensorSetItemByOneTensorWithTupleOfNumber(Cell):
    """Setitem through one index tensor with a tuple of numbers.

    ``test_cases`` builds it with an 8-element tuple, matching the parameter's
    last-axis length of 8.
    """
    def __init__(self, value):
        super(TensorSetItemByOneTensorWithTupleOfNumber, self).__init__()
        self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
        self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
        self.value = value

    def construct(self, index):
        self.param[index] = self.value
        ret = self.param + self.const
        return ret
class TensorSetItemByOneTensorWithTupleOfTensor(Cell):
    """Setitem through one index tensor with a tuple of three tensors.

    The parameter is (6, 3, 8). The 3-tuple presumably lines up with its
    middle axis of length 3 (TODO confirm against the test case).
    """
    def __init__(self):
        super(TensorSetItemByOneTensorWithTupleOfTensor, self).__init__()
        self.const = Tensor(np.ones((6, 3, 8)), mstype.float32)
        self.param = Parameter(Tensor(np.arange(6 * 3 * 8).reshape((6, 3, 8)), mstype.float32), name="x")

    def construct(self, index, value_0, value_1, value_2):
        self.param[index] = (value_0, value_1, value_2)
        ret = self.param + self.const
        return ret
class TensorSetItemByTensorsWithNumber(Cell):
    """Setitem through three index tensors (one per axis) with a number value."""
    def __init__(self, value):
        super(TensorSetItemByTensorsWithNumber, self).__init__()
        self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
        self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
        self.value = value

    def construct(self, index_0, index_1, index_2):
        self.param[index_0, index_1, index_2] = self.value
        ret = self.param + self.const
        return ret
class TensorSetItemByTensorsWithTensor(Cell):
    """Setitem through three index tensors (one per axis) with a tensor value."""
    def __init__(self):
        super(TensorSetItemByTensorsWithTensor, self).__init__()
        self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
        self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")

    def construct(self, index_0, index_1, index_2, value):
        self.param[index_0, index_1, index_2] = value
        ret = self.param + self.const
        return ret
class TensorSetItemByTensorsWithTensorNumberError(Cell):
    """Error case: four index tensors applied to a rank-3 (6, 7, 8) parameter.

    There are more indices than axes, so the setitem is expected to be
    rejected.
    """
    def __init__(self):
        super(TensorSetItemByTensorsWithTensorNumberError, self).__init__()
        self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
        self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")

    def construct(self, index_0, index_1, index_2, index_3, value):
        self.param[index_0, index_1, index_2, index_3] = value
        ret = self.param + self.const
        return ret
class TensorSetItemByTensorsWithTupleOfNumber(Cell):
    """Setitem through three index tensors with a tuple-of-numbers value fixed at construction."""
    def __init__(self, value):
        super(TensorSetItemByTensorsWithTupleOfNumber, self).__init__()
        self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
        self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
        self.value = value

    def construct(self, index_0, index_1, index_2):
        self.param[index_0, index_1, index_2] = self.value
        ret = self.param + self.const
        return ret
class TensorSetItemByTensorsWithTupleOfTensor(Cell):
    """Setitem through three index tensors with a tuple of three tensor values."""
    def __init__(self):
        super(TensorSetItemByTensorsWithTupleOfTensor, self).__init__()
        self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
        self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")

    def construct(self, index_0, index_1, index_2, value_0, value_1, value_2):
        self.param[index_0, index_1, index_2] = (value_0, value_1, value_2)
        ret = self.param + self.const
        return ret
class TensorSetItemByTensorsWithTupleOfTensorNumberError(Cell):
    """Error case: a 2-tuple value where the sibling valid case uses a 3-tuple.

    Presumably rejected because the tuple length does not fit the indexed
    region (TODO confirm against the test case that uses it).
    """
    def __init__(self):
        super(TensorSetItemByTensorsWithTupleOfTensorNumberError, self).__init__()
        self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
        self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")

    def construct(self, index_0, index_1, index_2, value_0, value_1):
        self.param[index_0, index_1, index_2] = (value_0, value_1)
        ret = self.param + self.const
        return ret
class TensorSetItemByMixedTensors(Cell):
    """Setitem through two index tensors plus a slice, with the scalar 99.0.

    ``0:6`` covers the first 6 of the 8 entries on the parameter's last axis.
    """
    def __init__(self):
        super(TensorSetItemByMixedTensors, self).__init__()
        self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
        self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
        self.value = 99.0

    def construct(self, index_0, index_1):
        self.param[index_0, index_1, 0:6] = self.value
        ret = self.param + self.const
        return ret
  339. def test_tensor_assign():
  340. context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
  341. net = TensorAssignWithSlice()
  342. net2 = TensorAssignWithSlice2()
  343. net_e1 = TensorAssignWithSliceError1()
  344. net_e2 = TensorAssignWithSliceError2()
  345. a = np.arange(60).reshape(3, 4, 5)
  346. ck = np.arange(60).reshape(3, 4, 5)
  347. b = Tensor([1], dtype=mstype.float32)
  348. Ta = Tensor(a, dtype=mstype.float32)
  349. Tck = Tensor(ck, dtype=mstype.float32)
  350. Ta4d = Tensor(a.reshape(1, 3, 4, 5), dtype=mstype.float32)
  351. Ta4d_ck = Tensor(ck.reshape(1, 3, 4, 5), dtype=mstype.float32)
  352. Tb = Tensor([1, 3], dtype=mstype.float32)
  353. Tc = Tensor([], dtype=mstype.float32)
  354. t = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
  355. tck = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
  356. net(Ta, b, Tck)
  357. net2(t, b, tck)
  358. # Error for A[Slice] = Number
  359. # 1. A[Slice] = Number, 0 in shape
  360. with pytest.raises(ValueError):
  361. net_e2(t, Tensor(2, mstype.int32))
  362. # Error for A[Slice] = U, U is a Tensor
  363. # 1. A[Slice] = U, u.size is error
  364. with pytest.raises(ValueError):
  365. net2(t, Tb, tck)
  366. # 2. A[Slice] = U, U is empty
  367. with pytest.raises(ValueError):
  368. net2(t, Tc, tck)
  369. # 3. A[Slice] = U, U.size error
  370. with pytest.raises(ValueError):
  371. net2(t, Tb, tck)
  372. # Error for A[Tuple(Slice...)] = Tensor
  373. # 1. A[Tuple(Slice...)] = U, U is empty
  374. with pytest.raises(ValueError):
  375. net(Ta, Tc, Tck)
  376. # 2. A[Tuple(Slice...)] = U, U.size error
  377. with pytest.raises(ValueError):
  378. net(Ta, Tb, Tck)
  379. # 3. A[Tuple(Slice...)] = U, Slice error
  380. with pytest.raises(IndexError):
  381. net_e1(Ta, b)
  382. # Error for A[Tuple(Slice...)] = Number
  383. # 1. A[Tuple(Slice...)] = Number, Slice error
  384. with pytest.raises(IndexError):
  385. net_e1(Ta, Tensor(2, mstype.int32))
  386. net = TensorAssignWithInteger()
  387. # Error for A[Number] = scalar/Tensor
  388. # 1. A[Number] = U, U is a Tensor, u.size not match
  389. with pytest.raises(ValueError):
  390. net(Ta, Tb, Tck)
  391. with pytest.raises(ValueError):
  392. net(Ta, Tc, Tck)
  393. # 2. A[Number] = U, the number index error
  394. with pytest.raises(IndexError):
  395. net(Ta4d, b, Ta4d_ck)
  396. # Error for A[(n,m)] = scalar/Tensor
  397. # 1. A[(n,m)] = U, U is a tensor. u.size not match
  398. net = TensorAssignWithTupleInteger()
  399. with pytest.raises(ValueError):
  400. net(Ta, Tc, Tck)
  401. with pytest.raises(ValueError):
  402. net(Ta, Tb, Tck)
  403. # 2. A[(n,m)] = U, the number index error
  404. with pytest.raises(IndexError):
  405. net(Ta4d, b, Ta4d_ck)
  406. # Error for A[...] = U or A[1:, ...] = u
  407. # 1. A[...] = scalar/tensor
  408. net = TensorAssignWithEllipsis()
  409. net(Ta, Ta4d)
  410. with pytest.raises(ValueError):
  411. net(Ta, Tc)
  412. with pytest.raises(ValueError):
  413. net(Ta, Tb)
  414. # 2. A[::, 1:, ...] = scalar/tensor
  415. net = TensorAssignWithTupleEllipsis()
  416. net(Ta, b)
  417. Tc = Tensor(1, mstype.float32)
  418. net(Ta, Tc)
  419. with pytest.raises(ValueError):
  420. net(Ta, Tb)
class TensorAssignWithTupleEllipsis2(Cell):
    """Assignment through slice, Ellipsis, slice on a 4-D tensor.

    ``test_cases`` feeds ``Ta4`` (shape (3, 2, 2, 5)) and ``u_tensor``.
    """
    def __init__(self):
        super(TensorAssignWithTupleEllipsis2, self).__init__()

    def construct(self, a, b):
        a[1:, ..., ::] = b
        return a
class TensorAssignWithTupleEllipsis(Cell):
    """Assignment through a leading slice followed by Ellipsis.

    The first statement uses a float constant; the second uses the input ``b``.
    """
    def __init__(self):
        super(TensorAssignWithTupleEllipsis, self).__init__()

    def construct(self, a, b):
        a[:2, ...] = 1.0
        a[1:, ...] = b
        return a
class TensorAssignWithEllipsis(Cell):
    """Whole-tensor assignment via a bare Ellipsis, first with an int, then with ``b``."""
    def __init__(self):
        super(TensorAssignWithEllipsis, self).__init__()

    def construct(self, a, b):
        a[...] = 1
        a[...] = b
        return a
class TensorAssignWithInteger(Cell):
    """Assignment through a single integer index, with a number and then a tensor.

    ``test_tensor_assign`` expects an IndexError when ``a`` has a leading
    axis of size 1, because index 1 is out of range there.
    """
    def __init__(self):
        super(TensorAssignWithInteger, self).__init__()

    def construct(self, a, b, ck):
        a[1] = 1
        a[0] = b
        z = a + ck
        return z
class TensorAssignWithTupleInteger(Cell):
    """Assignment through integer and tuple-of-integer indices.

    NOTE(review): ``a[(1)]`` is the same as ``a[1]`` -- the parentheses do
    not make a tuple; a 1-tuple index would be ``a[(1,)]``. Only the
    ``a[(1, 1)]`` statements actually use a tuple index. Confirm which form
    was intended.
    """
    def __init__(self):
        super(TensorAssignWithTupleInteger, self).__init__()

    def construct(self, a, b, ck):
        a[(1)] = 1.0
        a[(1)] = b
        a[(1, 1)] = b
        a[(1, 1)] = 1.0
        z = a + ck
        return z
class TensorAssignWithBoolTensorIndex(Cell):
    """Assignment through boolean-mask tensors passed in as inputs.

    In this file ``b``/``c`` are the masks ``a > 5`` / ``a < 3`` over a
    (3, 4, 5) array. ``c`` receives a Python scalar and ``b`` receives a
    tensor value.
    """
    def __init__(self):
        super(TensorAssignWithBoolTensorIndex, self).__init__()
        # Offset added to the result; same (3, 4, 5) shape as the input `a`.
        self.t = Tensor(np.arange(60).reshape([3, 4, 5]), dtype=mstype.float32)
        self.u_scalar = 5

    def construct(self, a, b, c, u_tensor):
        a[c] = self.u_scalar
        a[b] = u_tensor
        z = a + self.t
        return z
class TensorAssignWithBoolTensorIndexError(Cell):
    """Error case: chained mask indexing as an assignment target (``a[b][c] = ...``).

    ``test_tensor_assign_bool_index`` expects an IndexError for this cell,
    with both tensor and scalar values.
    """
    def __init__(self):
        super(TensorAssignWithBoolTensorIndexError, self).__init__()

    def construct(self, a, b, c, u_tensor):
        a[b][c] = u_tensor
        return a
  475. class TensorAssignWithBoolTensorIndex2(Cell):
  476. def __init__(self):
  477. super(TensorAssignWithBoolTensorIndex2, self).__init__()
  478. self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float32)
  479. self.t = Tensor(np.arange(60).reshape([3, 4, 5]), dtype=mstype.float32)
  480. self.u_scalar = 5
  481. def construct(self, a, u_tensor):
  482. a[a > 8] = u_tensor
  483. a[a >= 6] = self.u_scalar
  484. a[a < 3] = self.u_scalar
  485. a[a <= 5] = u_tensor
  486. a[a == 5] = self.u_scalar
  487. z = a + self.t
  488. return z
class TensorAssignWithBoolTensorIndex2Error(Cell):
    """Error case: chained comparison-mask indexing as an assignment target.

    ``test_tensor_assign_bool_index`` expects an IndexError for this cell.
    """
    def __init__(self):
        super(TensorAssignWithBoolTensorIndex2Error, self).__init__()

    def construct(self, a, u_tensor):
        a[a > 8][a > 5] = u_tensor
        return a
  495. a = np.arange(60).reshape(3, 4, 5)
  496. ck = np.arange(60).reshape(3, 4, 5)
  497. a4 = np.arange(60).reshape(3, 2, 2, 5)
  498. b = a > 5
  499. c = a < 3
  500. Ta = Tensor(a, dtype=mstype.float32)
  501. Tck = Tensor(ck, dtype=mstype.float32)
  502. Ta4 = Tensor(a4, dtype=mstype.float32)
  503. Tb = Tensor(b)
  504. Tc = Tensor(c)
  505. Td = Tensor([True, True])
  506. u_tensor = Tensor([1], dtype=mstype.float32)
  507. u_tensor_error = Tensor([1, 2], dtype=mstype.float32)
  508. t_1d = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
  509. tck_1d = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
  510. u_scalar = 5
  511. def test_tensor_assign_bool_index():
  512. net1 = TensorAssignWithBoolTensorIndex()
  513. net2 = TensorAssignWithBoolTensorIndex2()
  514. net1(Ta, Tb, Tc, u_tensor)
  515. net1(Ta, Tb, Tc, u_tensor)
  516. with pytest.raises(ValueError):
  517. net1(Ta, Td, Tc, u_tensor)
  518. with pytest.raises(IndexError):
  519. net1(Ta, u_tensor, Tc, u_tensor)
  520. with pytest.raises(ValueError):
  521. net1(Ta, Tb, Td, u_tensor)
  522. with pytest.raises(IndexError):
  523. net1(Ta, Tb, Ta, u_tensor)
  524. with pytest.raises(ValueError):
  525. net1(Ta, Tb, Tc, u_tensor_error)
  526. # net1(Ta, u_tensor, Tc, u_tensor_error, u_scalar)
  527. with pytest.raises(ValueError):
  528. net2(Ta, u_tensor_error)
  529. net3 = TensorAssignWithBoolTensorIndexError()
  530. with pytest.raises(IndexError):
  531. net3(Ta, Tb, Tc, u_tensor)
  532. with pytest.raises(IndexError):
  533. net3(Ta, Tb, Tc, Tensor(u_scalar, mstype.int32))
  534. net4 = TensorAssignWithBoolTensorIndex2Error()
  535. with pytest.raises(IndexError):
  536. net4(Ta, u_tensor)
  537. with pytest.raises(IndexError):
  538. net4(Ta, Tensor(u_scalar, mstype.int32))
  539. test_cases = [
  540. ('TensorAssignWithTupleEllipsis2', {
  541. 'block': TensorAssignWithTupleEllipsis2(),
  542. 'desc_inputs': [Ta4, u_tensor],
  543. }),
  544. ('TensorAssignWithTupleEllipsis', {
  545. 'block': TensorAssignWithTupleEllipsis(),
  546. 'desc_inputs': [Ta, u_tensor],
  547. }),
  548. ('TensorAssignWithEllipsis', {
  549. 'block': TensorAssignWithEllipsis(),
  550. 'desc_inputs': [Ta, u_tensor],
  551. }),
  552. ('TensorAssignWithTupleInteger', {
  553. 'block': TensorAssignWithTupleInteger(),
  554. 'desc_inputs': [Ta, u_tensor, Tck],
  555. }),
  556. ('TensorAssignWithInteger', {
  557. 'block': TensorAssignWithInteger(),
  558. 'desc_inputs': [Ta, u_tensor, Tck],
  559. }),
  560. ('TensorAssignWithSlice', {
  561. 'block': TensorAssignWithSlice(),
  562. 'desc_inputs': [Ta, u_tensor, Tck],
  563. }),
  564. ('TensorAssignWithSlice2', {
  565. 'block': TensorAssignWithSlice2(),
  566. 'desc_inputs': [t_1d, u_tensor, tck_1d],
  567. }),
  568. ('TensorAssignWithBoolTensorIndex', {
  569. 'block': TensorAssignWithBoolTensorIndex(),
  570. 'desc_inputs': [Ta, Tb, Tc, u_tensor],
  571. }),
  572. ('TensorAssignWithBoolTensorIndex2', {
  573. 'block': TensorAssignWithBoolTensorIndex2(),
  574. 'desc_inputs': [Ta, u_tensor],
  575. }),
  576. ('SlicePositive', {
  577. 'block': NetWorkSlicePositive(),
  578. 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))],
  579. }),
  580. ('SliceReduceDimension', {
  581. 'block': NetWorkReduceDimension(),
  582. 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))],
  583. }),
  584. ('SliceNegative', {
  585. 'block': NetWorkStepNegative(),
  586. 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))],
  587. }),
  588. ('SliceReduceToScalar', {
  589. 'block': NetWorkReduceToScalar(),
  590. 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))],
  591. }),
  592. ('TensorSliceEllipsis', {
  593. 'block': NetWorkSliceEllipsis(),
  594. 'desc_inputs': [Tensor(np.ones([6, 7, 8, 9], np.int32))],
  595. }),
  596. ('TensorGetItemByOneTensor', {
  597. 'block': TensorGetItemByOneTensor(),
  598. 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
  599. Tensor(np.random.randint(6, size=(5, 4)), mstype.int32)],
  600. }),
  601. ('TensorGetItemByTwoTensors', {
  602. 'block': TensorGetItemByTwoTensors(),
  603. 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
  604. Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  605. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)],
  606. }),
  607. ('TensorGetItemByThreeTensors', {
  608. 'block': TensorGetItemByThreeTensors(),
  609. 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
  610. Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  611. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  612. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
  613. }),
  614. ('TensorGetItemByMixedTensors_0', {
  615. 'block': TensorGetItemByMixedTensors_0(),
  616. 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
  617. Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  618. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32)],
  619. }),
  620. ('TensorGetItemByMixedTensors_1', {
  621. 'block': TensorGetItemByMixedTensors_1(),
  622. 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
  623. Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  624. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32)],
  625. }),
  626. ('TensorGetItemByMixedTensors_2', {
  627. 'block': TensorGetItemByMixedTensors_2(),
  628. 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
  629. Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  630. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32)],
  631. }),
  632. ('TensorGetItemByMixedTensors_3', {
  633. 'block': TensorGetItemByMixedTensors_3(),
  634. 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
  635. Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  636. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32)],
  637. }),
  638. ('TensorGetItemByMixedTensors_4', {
  639. 'block': TensorGetItemByMixedTensors_4(),
  640. 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.float32),
  641. Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  642. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  643. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  644. }),
  645. ('TensorGetItemByMixedTensors_5', {
  646. 'block': TensorGetItemByMixedTensors_5(),
  647. 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
  648. Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  649. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  650. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  651. }),
  652. ('TensorGetItemByMixedTensors_6', {
  653. 'block': TensorGetItemByMixedTensors_6(),
  654. 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
  655. Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  656. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  657. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  658. }),
  659. ('TensorSetItemByOneTensorWithNumber', {
  660. 'block': TensorSetItemByOneTensorWithNumber(value=0.0),
  661. 'desc_inputs': [Tensor(np.random.randint(4, size=(5, 4)), mstype.int32)],
  662. }),
  663. ('TensorSetItemByOneTensorWithTensor', {
  664. 'block': TensorSetItemByOneTensorWithTensor(),
  665. 'desc_inputs': [Tensor(np.random.randint(3, size=(5, 4)), mstype.int32),
  666. Tensor(np.zeros((4, 7, 8)), mstype.float32)],
  667. }),
  668. ('TensorSetItemByOneTensorWithTupleOfNumber', {
  669. 'block': TensorSetItemByOneTensorWithTupleOfNumber(value=(0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7)),
  670. 'desc_inputs': [Tensor(np.random.randint(5, size=(5, 4)), mstype.int32)],
  671. }),
  672. ('TensorSetItemByOneTensorWithTupleOfTensor', {
  673. 'block': TensorSetItemByOneTensorWithTupleOfTensor(),
  674. 'desc_inputs': [Tensor(np.random.randint(6, size=(5, 4)), mstype.int32),
  675. Tensor(np.zeros((8,), np.float32)),
  676. Tensor(np.ones((8,), np.float32)),
  677. Tensor(np.ones((8,), np.float32) * 2)],
  678. }),
  679. ('TensorSetItemByTensorsWithNumber', {
  680. 'block': TensorSetItemByTensorsWithNumber(value=0.0),
  681. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  682. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  683. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
  684. }),
  685. ('TensorSetItemByTensorsWithTensor', {
  686. 'block': TensorSetItemByTensorsWithTensor(),
  687. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  688. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  689. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
  690. Tensor(np.zeros((4, 5)), mstype.float32)],
  691. }),
  692. ('TensorSetItemByTensorsWithTupleOfNumber', {
  693. 'block': TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1.1, 2.2, 3.3, 4.4)),
  694. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  695. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  696. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
  697. }),
  698. ('TensorSetItemByTensorsWithTupleOfTensor', {
  699. 'block': TensorSetItemByTensorsWithTupleOfTensor(),
  700. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  701. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  702. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
  703. Tensor(np.zeros((4, 5)), mstype.float32),
  704. Tensor(np.ones((4, 5)), mstype.float32),
  705. Tensor(np.ones((4, 5)) * 2, mstype.float32)],
  706. }),
  707. ('TensorSetItemByMixedTensorsWithNumber_0', {
  708. 'block': TensorSetItemByMixedTensors_0(value=88.0),
  709. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  710. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  711. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  712. }),
  713. ('TensorSetItemByMixedTensorsWithTensor_0', {
  714. 'block': TensorSetItemByMixedTensors_0(value=Tensor(np.ones((4, 5, 3, 9), np.float32))),
  715. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  716. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  717. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  718. }),
  719. ('TensorGetItemByMixedTensorsWithTupleOfNumber_0', {
  720. 'block': TensorSetItemByMixedTensors_0(value=(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0)),
  721. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  722. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  723. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  724. }),
  725. ('TensorGetItemByMixedTensorsWithTupleOfTensor_0', {
  726. 'block': TensorSetItemByMixedTensors_0(value=(Tensor(np.ones((4, 5, 3, 9), np.float32)),
  727. Tensor(np.zeros((4, 5, 3, 9), np.float32)),
  728. Tensor(np.ones((4, 5, 3, 9), np.float32)))),
  729. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  730. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  731. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  732. }),
  733. ('TensorSetItemByMixedTensorsWithNumber_1', {
  734. 'block': TensorSetItemByMixedTensors_1(value=88.0),
  735. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  736. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  737. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  738. }),
  739. ('TensorSetItemByMixedTensorsWithTensor_1', {
  740. 'block': TensorSetItemByMixedTensors_1(value=Tensor(np.ones((5, 2, 6), np.float32))),
  741. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  742. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  743. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  744. }),
  745. ('TensorGetItemByMixedTensorsWithTupleOfNumber_1', {
  746. 'block': TensorSetItemByMixedTensors_1(value=(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)),
  747. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  748. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  749. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  750. }),
  751. ('TensorGetItemByMixedTensorsWithTupleOfTensor_1', {
  752. 'block': TensorSetItemByMixedTensors_1(value=(Tensor(np.ones((5, 2, 6), np.float32)),
  753. Tensor(np.zeros((5, 2, 6), np.float32)),
  754. Tensor(np.ones((5, 2, 6), np.float32)),
  755. Tensor(np.ones((5, 2, 6), np.float32)))),
  756. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  757. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  758. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  759. }),
  760. ('TensorSetItemByMixedTensorsWithNumber_2', {
  761. 'block': TensorSetItemByMixedTensors_2(value=88.0),
  762. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  763. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  764. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  765. }),
  766. ('TensorSetItemByMixedTensorsWithTensor_2', {
  767. 'block': TensorSetItemByMixedTensors_2(value=Tensor(np.ones((3, 4, 2, 3, 4, 5), np.float16))),
  768. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  769. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  770. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  771. }),
  772. ('TensorGetItemByMixedTensorsWithTupleOfNumber_2', {
  773. 'block': TensorSetItemByMixedTensors_2(value=(1.0, 2.0, 3.0, 4.0, 5.0)),
  774. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  775. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  776. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  777. }),
  778. ('TensorGetItemByMixedTensorsWithTupleOfTensor_2', {
  779. 'block': TensorSetItemByMixedTensors_2(value=(Tensor(np.ones((4, 5), np.float16)),
  780. Tensor(np.zeros((4, 5), np.float16)),
  781. Tensor(np.ones((4, 5), np.float16)))),
  782. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  783. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  784. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  785. }),
  786. ]
  787. raise_error_set = [
  788. ('TensorGetItemByOneTensorDtypeError', {
  789. 'block': (TensorGetItemByOneTensor(), {'exception': IndexError}),
  790. 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
  791. Tensor(np.random.randint(6, size=(5, 4)), mstype.int8)],
  792. }),
  793. ('TensorGetItemByTwoTensorsShapeError', {
  794. 'block': (TensorGetItemByTwoTensors(), {'exception': IndexError}),
  795. 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
  796. Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  797. Tensor(np.random.randint(7, size=(2, 3, 5)), mstype.int32)],
  798. }),
  799. ('TensorGetItemByTwoTensorsDtypeError', {
  800. 'block': (TensorGetItemByTwoTensors(), {'exception': IndexError}),
  801. 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
  802. Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  803. Tensor(np.random.randint(7, size=(4, 5)), mstype.float32)],
  804. }),
  805. ('TensorGetItemByThreeTensorsShapeError', {
  806. 'block': (TensorGetItemByThreeTensors(), {'exception': IndexError}),
  807. 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
  808. Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  809. Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32),
  810. Tensor(np.random.randint(8, size=(5, 2, 4, 5)), mstype.int32)],
  811. }),
  812. ('TensorGetItemByThreeTensorsDtypeError', {
  813. 'block': (TensorGetItemByThreeTensors(), {'exception': IndexError}),
  814. 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
  815. Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  816. Tensor(np.random.randint(7, size=(4, 5)), mstype.int64),
  817. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
  818. }),
  819. ('TensorGetItemByMixedTensorsNumberError', {
  820. 'block': (TensorGetItemByMixedTensorsNumberError(), {'exception': IndexError}),
  821. 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
  822. Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  823. Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32)],
  824. }),
  825. ('TensorGetItemByMixedTensorsTypeError', {
  826. 'block': (TensorGetItemByMixedTensorsTypeError(), {'exception': IndexError}),
  827. 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32),
  828. Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  829. Tensor(np.random.randint(4, size=(3, 4, 5)), mstype.int32)],
  830. }),
  831. ('TensorGetItemByMixedTensorsDtypeError', {
  832. 'block': (TensorGetItemByMixedTensors_0(), {'exception': IndexError}),
  833. 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32),
  834. Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  835. Tensor(np.random.randint(4, size=(4, 5)), mstype.float32)],
  836. }),
  837. ('TensorGetItemByMixedTensorsShapeError', {
  838. 'block': (TensorGetItemByMixedTensors_0(), {'exception': IndexError}),
  839. 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32),
  840. Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  841. Tensor(np.random.randint(4, size=(2, 4, 5)), mstype.int32)],
  842. }),
  843. ('TensorSetItemByOneTensorWithNumberTypeError', {
  844. 'block': (TensorSetItemByOneTensorWithNumber(value=0), {'exception': TypeError}),
  845. 'desc_inputs': [Tensor(np.random.randint(4, size=(5, 4)), mstype.int32)],
  846. }),
  847. ('TensorSetItemByOneTensorWithTensorShapeError', {
  848. 'block': (TensorSetItemByOneTensorWithTensor(), {'exception': ValueError}),
  849. 'desc_inputs': [Tensor(np.random.randint(3, size=(5, 4)), mstype.int32),
  850. Tensor(np.zeros((6, 7, 8)), mstype.float32)],
  851. }),
  852. ('TensorSetItemByOneTensorWithTensorDtypeError', {
  853. 'block': (TensorSetItemByOneTensorWithTensor(), {'exception': TypeError}),
  854. 'desc_inputs': [Tensor(np.random.randint(3, size=(5, 4)), mstype.int32),
  855. Tensor(np.zeros((6, 7, 8)), mstype.int32)],
  856. }),
  857. ('TensorSetItemByOneTensorWithTupleOfNumberTypeError', {
  858. 'block': (TensorSetItemByOneTensorWithTupleOfNumber(value=(0, 1, 2, 3, 4, 5, 6, 7)), {'exception': TypeError}),
  859. 'desc_inputs': [Tensor(np.random.randint(5, size=(5, 4)), mstype.int32)],
  860. }),
  861. ('TensorSetItemByOneTensorWithTupleOfNumberNumberError', {
  862. 'block': (TensorSetItemByOneTensorWithTupleOfNumber(value=(0.0, 1.1, 2.2)), {'exception': ValueError}),
  863. 'desc_inputs': [Tensor(np.random.randint(5, size=(5, 4)), mstype.int32)],
  864. }),
  865. ('TensorSetItemByOneTensorWithTupleOfTensorDtyeError', {
  866. 'block': (TensorSetItemByOneTensorWithTupleOfTensor(), {'exception': TypeError}),
  867. 'desc_inputs': [Tensor(np.random.randint(6, size=(5, 4)), mstype.int32),
  868. Tensor(np.zeros((8,), np.int32)),
  869. Tensor(np.ones((8,), np.int32)),
  870. Tensor(np.ones((8,), np.float32) * 2)],
  871. }),
  872. ('TensorSetItemByTensorsWithNumberTypeError', {
  873. 'block': (TensorSetItemByTensorsWithNumber(value=0), {'exception': TypeError}),
  874. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  875. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  876. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
  877. }),
  878. ('TensorSetItemByTensorsWithTensorShapeError', {
  879. 'block': (TensorSetItemByTensorsWithTensor(), {'exception': ValueError}),
  880. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  881. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  882. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
  883. Tensor(np.zeros((2, 5)), mstype.float32)],
  884. }),
  885. ('TensorSetItemByTensorsWithTensorTypeError', {
  886. 'block': (TensorSetItemByTensorsWithTensor(), {'exception': TypeError}),
  887. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  888. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  889. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
  890. Tensor(np.zeros((4, 5)), mstype.int32)],
  891. }),
  892. ('TensorSetItemByTensorsWithTensorNumberError', {
  893. 'block': (TensorSetItemByTensorsWithTensorNumberError(), {'exception': IndexError}),
  894. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  895. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  896. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
  897. Tensor(np.random.randint(8, size=(1, 3, 4, 5)), mstype.int32),
  898. Tensor(np.zeros((2, 5)), mstype.float32)],
  899. }),
  900. ('TensorSetItemByTensorsWithTupleOfNumberTypeError', {
  901. 'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1, 2, 3, 4)), {'exception': TypeError}),
  902. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  903. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  904. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
  905. }),
  906. ('TensorSetItemByTensorsWithTupleOfNumberNumberError', {
  907. 'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1.0, 2.0, 3.0)), {'exception': ValueError}),
  908. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  909. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  910. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
  911. }),
  912. ('TensorSetItemByTensorsWithTupleOfTensorNumberError', {
  913. 'block': (TensorSetItemByTensorsWithTupleOfTensorNumberError(), {'exception': ValueError}),
  914. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  915. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  916. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
  917. Tensor(np.zeros((4, 5)), mstype.float32),
  918. Tensor(np.ones((4, 5)), mstype.float32)],
  919. }),
  920. ('TensorSetItemByTensorsWithTupleOfTensorTypeError', {
  921. 'block': (TensorSetItemByTensorsWithTupleOfTensor(), {'exception': TypeError}),
  922. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  923. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  924. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
  925. Tensor(np.zeros((4, 5)), mstype.float32),
  926. Tensor(np.ones((4, 5)), mstype.int32),
  927. Tensor(np.ones((4, 5)) * 2, mstype.int32)],
  928. }),
  929. ('TensorSetItemByMixedTensorsWithNumberValueTypeError', {
  930. 'block': (TensorSetItemByMixedTensors_1(value=88), {'exception': TypeError}),
  931. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  932. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  933. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  934. }),
  935. ('TensorSetItemByMixedTensorsWithNumberIndexTypeError', {
  936. 'block': (TensorSetItemByMixedTensors_1(value=88.0), {'exception': IndexError}),
  937. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  938. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  939. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.float32)],
  940. }),
  941. ('TensorSetItemByMixedTensorsWithTensorValueDtypeError', {
  942. 'block': (TensorSetItemByMixedTensors_1(value=Tensor(np.ones((5, 2, 6), np.int32))),
  943. {'exception': TypeError}),
  944. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  945. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  946. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  947. }),
  948. ('TensorSetItemByMixedTensorsWithTensorValueShapeError', {
  949. 'block': (TensorSetItemByMixedTensors_1(value=Tensor(np.ones((3, 2, 6), np.float32))),
  950. {'exception': ValueError}),
  951. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  952. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  953. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  954. }),
  955. ('TensorSetItemByMixedTensorsWithTensorIndexDtypeError', {
  956. 'block': (TensorSetItemByMixedTensors_1(value=Tensor(np.ones((5, 2, 6), np.float32))),
  957. {'exception': IndexError}),
  958. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  959. Tensor(np.random.randint(4, size=(4, 5)), mstype.float32),
  960. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  961. }),
  962. ('TensorGetItemByMixedTensorsWithTupleOfNumberValueTypeError', {
  963. 'block': (TensorSetItemByMixedTensors_1(value=(1.0, 2, 3.0, 4.0, 5.0, 6.0)),
  964. {'exception': TypeError}),
  965. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  966. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  967. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  968. }),
  969. ('TensorGetItemByMixedTensorsWithTupleOfTensorValueDtypeError', {
  970. 'block': (TensorSetItemByMixedTensors_1(value=(Tensor(np.ones((5, 2, 6), np.float32)),
  971. Tensor(np.zeros((5, 2, 6), np.float32)),
  972. Tensor(np.ones((5, 2, 6), np.float32)),
  973. Tensor(np.ones((5, 2, 6), np.int32)))),
  974. {'exception': TypeError}),
  975. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  976. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  977. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  978. }),
  979. ('TensorGetItemByMixedTensorsWithTupleOfTensorIndexDtypeError', {
  980. 'block': (TensorSetItemByMixedTensors_1(value=(Tensor(np.ones((5, 2, 6), np.float32)),
  981. Tensor(np.zeros((5, 2, 6), np.float32)),
  982. Tensor(np.ones((5, 2, 6), np.float32)),
  983. Tensor(np.ones((5, 2, 6), np.int32)))),
  984. {'exception': IndexError}),
  985. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.float32),
  986. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  987. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  988. })
  989. ]
  990. @mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
  991. def test_exec():
  992. context.set_context(mode=context.GRAPH_MODE)
  993. return test_cases
  994. @mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception)
  995. def test_check_exception():
  996. return raise_error_set
  997. def test_tensor_slice_reduce_out_of_bounds_neg():
  998. class NetWork(Cell):
  999. def __init__(self):
  1000. super(NetWork, self).__init__()
  1001. self.tensor_ret = Tensor(np.array(9, np.int32))
  1002. def construct(self, tensor):
  1003. ret = tensor[-7, 3, 4]
  1004. return ret
  1005. input_tensor = Tensor(np.ones([6, 8, 10], np.int32))
  1006. net = NetWork()
  1007. with pytest.raises(IndexError):
  1008. net(input_tensor)
  1009. def test_tensor_slice_reduce_out_of_bounds_positive():
  1010. class NetWork(Cell):
  1011. def __init__(self):
  1012. super(NetWork, self).__init__()
  1013. self.tensor_ret = Tensor(np.array(9, np.int32))
  1014. def construct(self, tensor):
  1015. ret = tensor[6, 3, 4]
  1016. return ret
  1017. input_tensor = Tensor(np.ones([6, 8, 10], np.int32))
  1018. net = NetWork()
  1019. with pytest.raises(IndexError):
  1020. net(input_tensor)