You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_tensor_slice.py 53 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test_tensor_slice """
  16. import numpy as np
  17. import pytest
  18. from mindspore import Tensor, Parameter
  19. from mindspore import context
  20. from mindspore import dtype as mstype
  21. from mindspore.nn import Cell
  22. from mindspore.ops import operations as P
  23. from ....mindspore_test_framework.mindspore_test import mindspore_test
  24. from ....mindspore_test_framework.pipeline.forward.compile_forward \
  25. import pipeline_for_compile_forward_ge_graph_for_case_by_case_config, \
  26. pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception
  27. class NetWorkSlicePositive(Cell):
  28. def __init__(self):
  29. super(NetWorkSlicePositive, self).__init__()
  30. self.tensor_ret0 = Tensor(np.ones([1, 2, 2], np.int32))
  31. self.tensor_ret1 = Tensor(np.ones([4, 7, 4], np.int32))
  32. self.tensor_ret2 = Tensor(np.ones([6, 8, 10], np.int32))
  33. self.tensor_ret3 = Tensor(np.ones([3, 8, 10], np.int32))
  34. def construct(self, tensor):
  35. ret0 = tensor[3:4:3, 1:5:2, 3:6:2] + self.tensor_ret0
  36. ret1 = tensor[-6:4:1, 7:-8:-1, ::3] + self.tensor_ret1
  37. ret2 = tensor[::, ::, ::] + self.tensor_ret2
  38. ret3 = tensor[::2] + self.tensor_ret3
  39. return ret0, ret1, ret2, ret3
  40. class NetWorkSliceEllipsis(Cell):
  41. def __init__(self):
  42. super(NetWorkSliceEllipsis, self).__init__()
  43. self.tensor_ret0 = Tensor(np.ones([2, 7, 8], np.int32))
  44. self.tensor_ret1 = Tensor(np.ones([6, 7, 8, 9], np.int32))
  45. self.tensor_ret2 = Tensor(np.ones([1, 6, 7, 8, 9], np.int32))
  46. def construct(self, tensor):
  47. ret0 = tensor[0:4:2, ..., 1] + self.tensor_ret0
  48. ret1 = tensor[...] + self.tensor_ret1
  49. ret2 = tensor[None] + self.tensor_ret2
  50. ret3 = tensor[True] + self.tensor_ret2
  51. return ret0, ret1, ret2, ret3
  52. class NetWorkReduceDimension(Cell):
  53. def __init__(self):
  54. super(NetWorkReduceDimension, self).__init__()
  55. self.tensor_ret0 = Tensor(np.ones([2, 4, 1], np.int32))
  56. self.tensor_ret1 = Tensor(np.ones([3, 4], np.int32))
  57. self.tensor_ret2 = Tensor(np.ones([6, 8], np.int32))
  58. self.tensor_ret3 = Tensor(np.array(8, np.int32))
  59. self.tensor_ret4 = Tensor(np.ones([8, 10], np.int32))
  60. def construct(self, tensor):
  61. ret0 = tensor[0:6:3, 1:5:1, 3:5:2] + self.tensor_ret0
  62. ret1 = tensor[::2, 1, ::3] + self.tensor_ret1
  63. ret2 = tensor[::, ::, 0] + self.tensor_ret2
  64. ret3 = tensor[3, 2, 5] + self.tensor_ret3
  65. ret4 = tensor[1] + self.tensor_ret4
  66. return ret0, ret1, ret2, ret3, ret4
  67. class NetWorkStepNegative(Cell):
  68. def __init__(self):
  69. super(NetWorkStepNegative, self).__init__()
  70. self.tensor_ret = Tensor(np.ones([6, 5, 10], np.int32))
  71. def construct(self, tensor):
  72. ret = tensor[::1, -5::, ::-1] + self.tensor_ret
  73. return ret
  74. class NetWorkReduceToScalar(Cell):
  75. def __init__(self):
  76. super(NetWorkReduceToScalar, self).__init__()
  77. self.tensor_ret = Tensor(np.array(9, np.int32))
  78. def construct(self, tensor):
  79. ret = tensor[2, 3, 4] + self.tensor_ret
  80. return ret
  81. class TensorAssignWithSliceError1(Cell):
  82. def __init__(self):
  83. super(TensorAssignWithSliceError1, self).__init__()
  84. def construct(self, a, b):
  85. a[1:3:-1, ::] = b
  86. return a
  87. class TensorAssignWithSliceError2(Cell):
  88. def __init__(self):
  89. super(TensorAssignWithSliceError2, self).__init__()
  90. def construct(self, a, b):
  91. a[1:3:-1] = b
  92. return a
  93. class TensorAssignWithSlice2(Cell):
  94. def __init__(self):
  95. super(TensorAssignWithSlice2, self).__init__()
  96. def construct(self, a, b, ck):
  97. a[1:5] = b
  98. a[3:4] = 5
  99. a[-1:1:-1] = b
  100. a[-1:3:-1] = 5
  101. a[::] = b
  102. a[::] = 9
  103. z = a + ck
  104. return z
  105. class TensorAssignWithSlice(Cell):
  106. def __init__(self):
  107. super(TensorAssignWithSlice, self).__init__()
  108. self.c = 2.0
  109. def construct(self, a, b, ck):
  110. a[1:3, ::] = b
  111. a[2:3:, 3:] = b
  112. a[::] = b
  113. a[::] = self.c
  114. a[::, ::] = b
  115. a[::, ::] = self.c
  116. a[2:3:, 0:, 4:1:-1] = b
  117. a[2:3:, 0:, 4:1:-1] = self.c
  118. z = a + ck
  119. return z
  120. class TensorGetItemByOneTensor(Cell):
  121. def __init__(self):
  122. super(TensorGetItemByOneTensor, self).__init__()
  123. self.const = Tensor(np.ones((5, 4, 7, 8)), mstype.int32)
  124. def construct(self, x, index):
  125. ret = x[index] + self.const
  126. return ret
  127. class TensorGetItemByTwoTensors(Cell):
  128. def __init__(self):
  129. super(TensorGetItemByTwoTensors, self).__init__()
  130. self.const = Tensor(np.ones((3, 4, 5, 8)), mstype.int32)
  131. def construct(self, x, index_0, index_1):
  132. ret = x[index_0, index_1] + self.const
  133. return ret
  134. class TensorGetItemByThreeTensors(Cell):
  135. def __init__(self):
  136. super(TensorGetItemByThreeTensors, self).__init__()
  137. self.const = Tensor(np.ones((5, 3, 4, 5)), mstype.int32)
  138. def construct(self, x, index_0, index_1, index_2):
  139. ret = x[index_0, index_1, index_2] + self.const
  140. return ret
  141. class TensorGetItemByMixedTensors_0(Cell):
  142. def __init__(self):
  143. super(TensorGetItemByMixedTensors_0, self).__init__()
  144. self.const = Tensor(np.ones((3, 4, 5, 3, 6, 5), np.float32))
  145. def construct(self, tensor, index_0, index_1):
  146. ret = tensor[index_0, index_1, 0:3, ..., 0:5, 3] + self.const
  147. return ret
  148. class TensorGetItemByMixedTensors_1(Cell):
  149. def __init__(self):
  150. super(TensorGetItemByMixedTensors_1, self).__init__()
  151. self.const = Tensor(np.ones((3, 4, 5, 3, 5, 5), np.float32))
  152. def construct(self, tensor, index_0, index_1):
  153. ret = tensor[0:3, index_0, ..., index_1, 3, 0:5] + self.const
  154. return ret
  155. class TensorGetItemByMixedTensors_2(Cell):
  156. def __init__(self):
  157. super(TensorGetItemByMixedTensors_2, self).__init__()
  158. self.const = Tensor(np.ones((3, 4, 5, 6, 7), np.float32))
  159. def construct(self, tensor, index_0, index_1):
  160. ret = tensor[0, index_0, index_1, ..., 3] + self.const
  161. return ret
  162. class TensorGetItemByMixedTensors_3(Cell):
  163. def __init__(self):
  164. super(TensorGetItemByMixedTensors_3, self).__init__()
  165. self.const = Tensor(np.ones((3, 4, 5, 3, 4, 3, 5), np.float32))
  166. def construct(self, tensor, index_0, index_1):
  167. ret = tensor[..., index_0, 0:3, index_1, 0:5] + self.const
  168. return ret
  169. class TensorGetItemByMixedTensors_4(Cell):
  170. def __init__(self):
  171. super(TensorGetItemByMixedTensors_4, self).__init__()
  172. self.const = Tensor(np.ones((2, 2, 3, 4, 5, 3, 9), np.float32))
  173. def construct(self, tensor, index_0, index_1, index_2):
  174. ret = tensor[0:2, index_0, index_1, 2, index_2, 0:3, ...] + self.const
  175. return ret
  176. class TensorGetItemByMixedTensors_5(Cell):
  177. def __init__(self):
  178. super(TensorGetItemByMixedTensors_5, self).__init__()
  179. self.const = Tensor(np.ones((2, 3, 4, 5, 2, 6), np.float32))
  180. def construct(self, tensor, index_0, index_1, index_2):
  181. ret = tensor[0:2, index_0, index_1, ..., index_2, 2] + self.const
  182. return ret
  183. class TensorGetItemByMixedTensors_6(Cell):
  184. def __init__(self):
  185. super(TensorGetItemByMixedTensors_6, self).__init__()
  186. self.const = Tensor(np.ones((3, 4, 2, 3, 4, 5), np.float32))
  187. def construct(self, tensor, index_0, index_1, index_2):
  188. ret = tensor[..., index_0, index_1, index_2, 3] + self.const
  189. return ret
  190. class TensorSetItemByMixedTensors_0(Cell):
  191. def __init__(self, value):
  192. super(TensorSetItemByMixedTensors_0, self).__init__()
  193. self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8, 9), np.float32))
  194. self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)),
  195. mstype.float32),
  196. name="x")
  197. self.value = value
  198. def construct(self, index_0, index_1, index_2):
  199. self.param[0:2, index_0, index_1, 2, index_2, 0:3, ...] = self.value
  200. ret = self.param + self.const
  201. return ret
  202. class TensorSetItemByMixedTensors_1(Cell):
  203. def __init__(self, value):
  204. super(TensorSetItemByMixedTensors_1, self).__init__()
  205. self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float32))
  206. self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
  207. name="x")
  208. self.value = value
  209. def construct(self, index_0, index_1, index_2):
  210. self.param[0:2, index_0, index_1, ..., index_2, 2] = self.value
  211. ret = self.param + self.const
  212. return ret
  213. class TensorSetItemByMixedTensors_2(Cell):
  214. def __init__(self, value):
  215. super(TensorSetItemByMixedTensors_2, self).__init__()
  216. self.const = Tensor(np.ones((3, 4, 5, 6, 7, 8), np.float16))
  217. self.param = Parameter(Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float16),
  218. name="x")
  219. self.value = value
  220. def construct(self, index_0, index_1, index_2):
  221. self.param[..., index_0, index_1, index_2, 3] = self.value
  222. ret = self.param + self.const
  223. return ret
  224. class TensorGetItemByMixedTensorsTypeError(Cell):
  225. def __init__(self):
  226. super(TensorGetItemByMixedTensorsTypeError, self).__init__()
  227. def construct(self, x, index_0, index_1):
  228. ret = x[index_0, index_1, 0:3, ..., 0:5, [1, 2, 3, 4]]
  229. return ret
  230. class TensorGetItemByMixedTensorsNumberError(Cell):
  231. def __init__(self):
  232. super(TensorGetItemByMixedTensorsNumberError, self).__init__()
  233. def construct(self, x, index_0, index_1):
  234. ret = x[index_0, index_1, 0:3, ..., index_1, index_0]
  235. return ret
  236. class TensorSetItemByOneTensorWithNumber(Cell):
  237. def __init__(self, value):
  238. super(TensorSetItemByOneTensorWithNumber, self).__init__()
  239. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  240. self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
  241. self.value = value
  242. def construct(self, index):
  243. self.param[index] = self.value
  244. ret = self.param + self.const
  245. return ret
  246. class TensorSetItemByOneTensorWithTensor(Cell):
  247. def __init__(self):
  248. super(TensorSetItemByOneTensorWithTensor, self).__init__()
  249. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  250. self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
  251. def construct(self, index, value):
  252. self.param[index] = value
  253. ret = self.param + self.const
  254. return ret
  255. class TensorSetItemByOneTensorWithTupleOfNumber(Cell):
  256. def __init__(self, value):
  257. super(TensorSetItemByOneTensorWithTupleOfNumber, self).__init__()
  258. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  259. self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
  260. self.value = value
  261. def construct(self, index):
  262. self.param[index] = self.value
  263. ret = self.param + self.const
  264. return ret
  265. class TensorSetItemByOneTensorWithTupleOfTensor(Cell):
  266. def __init__(self):
  267. super(TensorSetItemByOneTensorWithTupleOfTensor, self).__init__()
  268. self.const = Tensor(np.ones((6, 3, 8)), mstype.float32)
  269. self.param = Parameter(Tensor(np.arange(6 * 3 * 8).reshape((6, 3, 8)), mstype.float32), name="x")
  270. def construct(self, index, value_0, value_1, value_2):
  271. self.param[index] = (value_0, value_1, value_2)
  272. ret = self.param + self.const
  273. return ret
  274. class TensorSetItemByTensorsWithNumber(Cell):
  275. def __init__(self, value):
  276. super(TensorSetItemByTensorsWithNumber, self).__init__()
  277. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  278. self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
  279. self.value = value
  280. def construct(self, index_0, index_1, index_2):
  281. self.param[index_0, index_1, index_2] = self.value
  282. ret = self.param + self.const
  283. return ret
  284. class TensorSetItemByTensorsWithTensor(Cell):
  285. def __init__(self):
  286. super(TensorSetItemByTensorsWithTensor, self).__init__()
  287. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  288. self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
  289. def construct(self, index_0, index_1, index_2, value):
  290. self.param[index_0, index_1, index_2] = value
  291. ret = self.param + self.const
  292. return ret
  293. class TensorSetItemByTensorsWithTensorNumberError(Cell):
  294. def __init__(self):
  295. super(TensorSetItemByTensorsWithTensorNumberError, self).__init__()
  296. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  297. self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
  298. def construct(self, index_0, index_1, index_2, index_3, value):
  299. self.param[index_0, index_1, index_2, index_3] = value
  300. ret = self.param + self.const
  301. return ret
  302. class TensorSetItemByTensorsWithTupleOfNumber(Cell):
  303. def __init__(self, value):
  304. super(TensorSetItemByTensorsWithTupleOfNumber, self).__init__()
  305. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  306. self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
  307. self.value = value
  308. def construct(self, index_0, index_1, index_2):
  309. self.param[index_0, index_1, index_2] = self.value
  310. ret = self.param + self.const
  311. return ret
  312. class TensorSetItemByTensorsWithTupleOfTensor(Cell):
  313. def __init__(self):
  314. super(TensorSetItemByTensorsWithTupleOfTensor, self).__init__()
  315. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  316. self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
  317. def construct(self, index_0, index_1, index_2, value_0, value_1, value_2):
  318. self.param[index_0, index_1, index_2] = (value_0, value_1, value_2)
  319. ret = self.param + self.const
  320. return ret
  321. class TensorSetItemByTensorsWithTupleOfTensorNumberError(Cell):
  322. def __init__(self):
  323. super(TensorSetItemByTensorsWithTupleOfTensorNumberError, self).__init__()
  324. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  325. self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
  326. def construct(self, index_0, index_1, index_2, value_0, value_1):
  327. self.param[index_0, index_1, index_2] = (value_0, value_1)
  328. ret = self.param + self.const
  329. return ret
  330. class TensorSetItemByMixedTensors(Cell):
  331. def __init__(self):
  332. super(TensorSetItemByMixedTensors, self).__init__()
  333. self.const = Tensor(np.ones((6, 7, 8)), mstype.float32)
  334. self.param = Parameter(Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.float32), name="x")
  335. self.value = 99.0
  336. def construct(self, index_0, index_1):
  337. self.param[index_0, index_1, 0:6] = self.value
  338. ret = self.param + self.const
  339. return ret
  340. def test_tensor_assign():
  341. context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
  342. net = TensorAssignWithSlice()
  343. net2 = TensorAssignWithSlice2()
  344. net_e1 = TensorAssignWithSliceError1()
  345. net_e2 = TensorAssignWithSliceError2()
  346. a = np.arange(60).reshape(3, 4, 5)
  347. ck = np.arange(60).reshape(3, 4, 5)
  348. b = Tensor([1], dtype=mstype.float32)
  349. Ta = Tensor(a, dtype=mstype.float32)
  350. Tck = Tensor(ck, dtype=mstype.float32)
  351. Ta4d = Tensor(a.reshape(1, 3, 4, 5), dtype=mstype.float32)
  352. Ta4d_ck = Tensor(ck.reshape(1, 3, 4, 5), dtype=mstype.float32)
  353. Tb = Tensor([1, 3], dtype=mstype.float32)
  354. Tc = Tensor([], dtype=mstype.float32)
  355. t = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
  356. tck = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
  357. net(Ta, b, Tck)
  358. net2(t, b, tck)
  359. # Error for A[Slice] = Number
  360. # 1. A[Slice] = Number, Slice error
  361. with pytest.raises(IndexError):
  362. net_e2(t, Tensor(2, mstype.int32))
  363. # Error for A[Slice] = U, U is a Tensor
  364. # 1. A[Slice] = U, u.size is error
  365. with pytest.raises(ValueError):
  366. net2(t, Tb, tck)
  367. # 2. A[Slice] = U, U is empty
  368. with pytest.raises(ValueError):
  369. net2(t, Tc, tck)
  370. # 3. A[Slice] = U, U.size error
  371. with pytest.raises(ValueError):
  372. net2(t, Tb, tck)
  373. # Error for A[Tuple(Slice...)] = Tensor
  374. # 1. A[Tuple(Slice...)] = U, U is empty
  375. with pytest.raises(ValueError):
  376. net(Ta, Tc, Tck)
  377. # 2. A[Tuple(Slice...)] = U, U.size error
  378. with pytest.raises(ValueError):
  379. net(Ta, Tb, Tck)
  380. # 3. A[Tuple(Slice...)] = U, Slice error
  381. with pytest.raises(IndexError):
  382. net_e1(Ta, b)
  383. # Error for A[Tuple(Slice...)] = Number
  384. # 1. A[Tuple(Slice...)] = Number, Slice error
  385. with pytest.raises(IndexError):
  386. net_e1(Ta, Tensor(2, mstype.int32))
  387. net = TensorAssignWithInteger()
  388. # Error for A[Number] = scalar/Tensor
  389. # 1. A[Number] = U, U is a Tensor, u.size not match
  390. with pytest.raises(ValueError):
  391. net(Ta, Tb, Tck)
  392. with pytest.raises(ValueError):
  393. net(Ta, Tc, Tck)
  394. # 2. A[Number] = U, the number index error
  395. with pytest.raises(IndexError):
  396. net(Ta4d, b, Ta4d_ck)
  397. # Error for A[(n,m)] = scalar/Tensor
  398. # 1. A[(n,m)] = U, U is a tensor. u.size not match
  399. net = TensorAssignWithTupleInteger()
  400. with pytest.raises(ValueError):
  401. net(Ta, Tc, Tck)
  402. with pytest.raises(ValueError):
  403. net(Ta, Tb, Tck)
  404. # 2. A[(n,m)] = U, the number index error
  405. with pytest.raises(IndexError):
  406. net(Ta4d, b, Ta4d_ck)
  407. # Error for A[...] = U or A[1:, ...] = u
  408. # 1. A[...] = scalar/tensor
  409. net = TensorAssignWithEllipsis()
  410. net(Ta, Ta4d)
  411. with pytest.raises(ValueError):
  412. net(Ta, Tc)
  413. with pytest.raises(ValueError):
  414. net(Ta, Tb)
  415. # 2. A[::, 1:, ...] = scalar/tensor
  416. net = TensorAssignWithTupleEllipsis()
  417. net(Ta, b)
  418. Tc = Tensor(1, mstype.float32)
  419. net(Ta, Tc)
  420. with pytest.raises(ValueError):
  421. net(Ta, Tb)
  422. class TensorAssignWithTupleEllipsis2(Cell):
  423. def __init__(self):
  424. super(TensorAssignWithTupleEllipsis2, self).__init__()
  425. def construct(self, a, b):
  426. a[1:, ..., ::] = b
  427. return a
  428. class TensorAssignWithTupleEllipsis(Cell):
  429. def __init__(self):
  430. super(TensorAssignWithTupleEllipsis, self).__init__()
  431. def construct(self, a, b):
  432. a[:2, ...] = 1.0
  433. a[1:, ...] = b
  434. return a
  435. class TensorAssignWithEllipsis(Cell):
  436. def __init__(self):
  437. super(TensorAssignWithEllipsis, self).__init__()
  438. def construct(self, a, b):
  439. a[...] = 1
  440. a[...] = b
  441. return a
  442. class TensorAssignWithInteger(Cell):
  443. def __init__(self):
  444. super(TensorAssignWithInteger, self).__init__()
  445. def construct(self, a, b, ck):
  446. a[1] = 1
  447. a[0] = b
  448. z = a + ck
  449. return z
  450. class TensorAssignWithTupleInteger(Cell):
  451. def __init__(self):
  452. super(TensorAssignWithTupleInteger, self).__init__()
  453. def construct(self, a, b, ck):
  454. a[(1)] = 1.0
  455. a[(1)] = b
  456. a[(1, 1)] = b
  457. a[(1, 1)] = 1.0
  458. z = a + ck
  459. return z
  460. class TensorAssignWithBoolTensorIndex(Cell):
  461. def __init__(self):
  462. super(TensorAssignWithBoolTensorIndex, self).__init__()
  463. self.t = Tensor(np.arange(60).reshape([3, 4, 5]), dtype=mstype.float32)
  464. self.u_scalar = 5
  465. def construct(self, a, b, c, u_tensor):
  466. a[c] = self.u_scalar
  467. a[b] = u_tensor
  468. z = a + self.t
  469. return z
  470. class TensorAssignWithBoolTensorIndexError(Cell):
  471. def __init__(self):
  472. super(TensorAssignWithBoolTensorIndexError, self).__init__()
  473. def construct(self, a, b, c, u_tensor):
  474. a[b][c] = u_tensor
  475. return a
  476. class TensorAssignWithBoolTensorIndex2(Cell):
  477. def __init__(self):
  478. super(TensorAssignWithBoolTensorIndex2, self).__init__()
  479. self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float32)
  480. self.t = Tensor(np.arange(60).reshape([3, 4, 5]), dtype=mstype.float32)
  481. self.u_scalar = 5
  482. def construct(self, a, u_tensor):
  483. a[a > 8] = u_tensor
  484. a[a >= 6] = self.u_scalar
  485. a[a < 3] = self.u_scalar
  486. a[a <= 5] = u_tensor
  487. a[a == 5] = self.u_scalar
  488. z = a + self.t
  489. return z
  490. class TensorAssignWithBoolTensorIndex2Error(Cell):
  491. def __init__(self):
  492. super(TensorAssignWithBoolTensorIndex2Error, self).__init__()
  493. def construct(self, a, u_tensor):
  494. a[a > 8][a > 5] = u_tensor
  495. return a
  496. a = np.arange(60).reshape(3, 4, 5)
  497. ck = np.arange(60).reshape(3, 4, 5)
  498. a4 = np.arange(60).reshape(3, 2, 2, 5)
  499. b = a > 5
  500. c = a < 3
  501. Ta = Tensor(a, dtype=mstype.float32)
  502. Tck = Tensor(ck, dtype=mstype.float32)
  503. Ta4 = Tensor(a4, dtype=mstype.float32)
  504. Tb = Tensor(b)
  505. Tc = Tensor(c)
  506. Td = Tensor([True, True])
  507. u_tensor = Tensor([1], dtype=mstype.float32)
  508. u_tensor_error = Tensor([1, 2], dtype=mstype.float32)
  509. t_1d = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
  510. tck_1d = Tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=mstype.float32)
  511. u_scalar = 5
  512. def test_tensor_assign_bool_index():
  513. net1 = TensorAssignWithBoolTensorIndex()
  514. net2 = TensorAssignWithBoolTensorIndex2()
  515. net1(Ta, Tb, Tc, u_tensor)
  516. net1(Ta, Tb, Tc, u_tensor)
  517. with pytest.raises(ValueError):
  518. net1(Ta, Td, Tc, u_tensor)
  519. with pytest.raises(IndexError):
  520. net1(Ta, u_tensor, Tc, u_tensor)
  521. with pytest.raises(ValueError):
  522. net1(Ta, Tb, Td, u_tensor)
  523. with pytest.raises(IndexError):
  524. net1(Ta, Tb, Ta, u_tensor)
  525. with pytest.raises(ValueError):
  526. net1(Ta, Tb, Tc, u_tensor_error)
  527. # net1(Ta, u_tensor, Tc, u_tensor_error, u_scalar)
  528. with pytest.raises(ValueError):
  529. net2(Ta, u_tensor_error)
  530. net3 = TensorAssignWithBoolTensorIndexError()
  531. with pytest.raises(IndexError):
  532. net3(Ta, Tb, Tc, u_tensor)
  533. with pytest.raises(IndexError):
  534. net3(Ta, Tb, Tc, Tensor(u_scalar, mstype.int32))
  535. net4 = TensorAssignWithBoolTensorIndex2Error()
  536. with pytest.raises(IndexError):
  537. net4(Ta, u_tensor)
  538. with pytest.raises(IndexError):
  539. net4(Ta, Tensor(u_scalar, mstype.int32))
  540. def test_trivial_call_function_twice_with_diff_key_value_para():
  541. class Net(Cell):
  542. def __init__(self):
  543. super(Net, self).__init__()
  544. self.arange = Tensor(np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))
  545. self.concat = P.Concat(axis=0)
  546. def compute(self, x, is_decoder):
  547. if is_decoder:
  548. return self.arange[:x]
  549. return self.arange[1:x + 1]
  550. def construct(self):
  551. result1 = self.compute(7, is_decoder=True)
  552. result2 = self.compute(6, is_decoder=False)
  553. return self.concat((result1, result2))
  554. net = Net()
  555. net()
  556. test_cases = [
  557. ('TensorAssignWithTupleEllipsis2', {
  558. 'block': TensorAssignWithTupleEllipsis2(),
  559. 'desc_inputs': [Ta4, u_tensor],
  560. }),
  561. ('TensorAssignWithTupleEllipsis', {
  562. 'block': TensorAssignWithTupleEllipsis(),
  563. 'desc_inputs': [Ta, u_tensor],
  564. }),
  565. ('TensorAssignWithEllipsis', {
  566. 'block': TensorAssignWithEllipsis(),
  567. 'desc_inputs': [Ta, u_tensor],
  568. }),
  569. ('TensorAssignWithTupleInteger', {
  570. 'block': TensorAssignWithTupleInteger(),
  571. 'desc_inputs': [Ta, u_tensor, Tck],
  572. }),
  573. ('TensorAssignWithInteger', {
  574. 'block': TensorAssignWithInteger(),
  575. 'desc_inputs': [Ta, u_tensor, Tck],
  576. }),
  577. ('TensorAssignWithSlice', {
  578. 'block': TensorAssignWithSlice(),
  579. 'desc_inputs': [Ta, u_tensor, Tck],
  580. }),
  581. ('TensorAssignWithSlice2', {
  582. 'block': TensorAssignWithSlice2(),
  583. 'desc_inputs': [t_1d, u_tensor, tck_1d],
  584. }),
  585. ('TensorAssignWithBoolTensorIndex', {
  586. 'block': TensorAssignWithBoolTensorIndex(),
  587. 'desc_inputs': [Ta, Tb, Tc, u_tensor],
  588. }),
  589. ('TensorAssignWithBoolTensorIndex2', {
  590. 'block': TensorAssignWithBoolTensorIndex2(),
  591. 'desc_inputs': [Ta, u_tensor],
  592. }),
  593. ('SlicePositive', {
  594. 'block': NetWorkSlicePositive(),
  595. 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))],
  596. }),
  597. ('SliceReduceDimension', {
  598. 'block': NetWorkReduceDimension(),
  599. 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))],
  600. }),
  601. ('SliceNegative', {
  602. 'block': NetWorkStepNegative(),
  603. 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))],
  604. }),
  605. ('SliceReduceToScalar', {
  606. 'block': NetWorkReduceToScalar(),
  607. 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))],
  608. }),
  609. ('TensorSliceEllipsis', {
  610. 'block': NetWorkSliceEllipsis(),
  611. 'desc_inputs': [Tensor(np.ones([6, 7, 8, 9], np.int32))],
  612. }),
  613. ('TensorGetItemByOneTensor', {
  614. 'block': TensorGetItemByOneTensor(),
  615. 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
  616. Tensor(np.random.randint(6, size=(5, 4)), mstype.int32)],
  617. }),
  618. ('TensorGetItemByTwoTensors', {
  619. 'block': TensorGetItemByTwoTensors(),
  620. 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
  621. Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  622. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32)],
  623. }),
  624. ('TensorGetItemByThreeTensors', {
  625. 'block': TensorGetItemByThreeTensors(),
  626. 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
  627. Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  628. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  629. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
  630. }),
  631. ('TensorGetItemByMixedTensors_0', {
  632. 'block': TensorGetItemByMixedTensors_0(),
  633. 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
  634. Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  635. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32)],
  636. }),
  637. ('TensorGetItemByMixedTensors_1', {
  638. 'block': TensorGetItemByMixedTensors_1(),
  639. 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
  640. Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  641. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32)],
  642. }),
  643. ('TensorGetItemByMixedTensors_2', {
  644. 'block': TensorGetItemByMixedTensors_2(),
  645. 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
  646. Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  647. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32)],
  648. }),
  649. ('TensorGetItemByMixedTensors_3', {
  650. 'block': TensorGetItemByMixedTensors_3(),
  651. 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
  652. Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  653. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32)],
  654. }),
  655. ('TensorGetItemByMixedTensors_4', {
  656. 'block': TensorGetItemByMixedTensors_4(),
  657. 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.float32),
  658. Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  659. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  660. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  661. }),
  662. ('TensorGetItemByMixedTensors_5', {
  663. 'block': TensorGetItemByMixedTensors_5(),
  664. 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
  665. Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  666. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  667. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  668. }),
  669. ('TensorGetItemByMixedTensors_6', {
  670. 'block': TensorGetItemByMixedTensors_6(),
  671. 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8).reshape((3, 4, 5, 6, 7, 8)), mstype.float32),
  672. Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  673. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  674. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  675. }),
  676. ('TensorSetItemByOneTensorWithNumber', {
  677. 'block': TensorSetItemByOneTensorWithNumber(value=0.0),
  678. 'desc_inputs': [Tensor(np.random.randint(4, size=(5, 4)), mstype.int32)],
  679. }),
  680. ('TensorSetItemByOneTensorWithTensor', {
  681. 'block': TensorSetItemByOneTensorWithTensor(),
  682. 'desc_inputs': [Tensor(np.random.randint(3, size=(5, 4)), mstype.int32),
  683. Tensor(np.zeros((4, 7, 8)), mstype.float32)],
  684. }),
  685. ('TensorSetItemByOneTensorWithTupleOfNumber', {
  686. 'block': TensorSetItemByOneTensorWithTupleOfNumber(value=(0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7)),
  687. 'desc_inputs': [Tensor(np.random.randint(5, size=(5, 4)), mstype.int32)],
  688. }),
  689. ('TensorSetItemByOneTensorWithTupleOfTensor', {
  690. 'block': TensorSetItemByOneTensorWithTupleOfTensor(),
  691. 'desc_inputs': [Tensor(np.random.randint(6, size=(5, 4)), mstype.int32),
  692. Tensor(np.zeros((8,), np.float32)),
  693. Tensor(np.ones((8,), np.float32)),
  694. Tensor(np.ones((8,), np.float32) * 2)],
  695. }),
  696. ('TensorSetItemByTensorsWithNumber', {
  697. 'block': TensorSetItemByTensorsWithNumber(value=0.0),
  698. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  699. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  700. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
  701. }),
  702. ('TensorSetItemByTensorsWithTensor', {
  703. 'block': TensorSetItemByTensorsWithTensor(),
  704. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  705. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  706. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
  707. Tensor(np.zeros((4, 5)), mstype.float32)],
  708. }),
  709. ('TensorSetItemByTensorsWithTupleOfNumber', {
  710. 'block': TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1.1, 2.2, 3.3, 4.4)),
  711. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  712. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  713. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
  714. }),
  715. ('TensorSetItemByTensorsWithTupleOfTensor', {
  716. 'block': TensorSetItemByTensorsWithTupleOfTensor(),
  717. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  718. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  719. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
  720. Tensor(np.zeros((4, 5)), mstype.float32),
  721. Tensor(np.ones((4, 5)), mstype.float32),
  722. Tensor(np.ones((4, 5)) * 2, mstype.float32)],
  723. }),
  724. ('TensorSetItemByMixedTensorsWithNumber_0', {
  725. 'block': TensorSetItemByMixedTensors_0(value=88.0),
  726. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  727. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  728. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  729. }),
  730. ('TensorSetItemByMixedTensorsWithTensor_0', {
  731. 'block': TensorSetItemByMixedTensors_0(value=Tensor(np.ones((4, 5, 3, 9), np.float32))),
  732. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  733. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  734. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  735. }),
  736. ('TensorGetItemByMixedTensorsWithTupleOfNumber_0', {
  737. 'block': TensorSetItemByMixedTensors_0(value=(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0)),
  738. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  739. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  740. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  741. }),
  742. ('TensorGetItemByMixedTensorsWithTupleOfTensor_0', {
  743. 'block': TensorSetItemByMixedTensors_0(value=(Tensor(np.ones((4, 5, 3, 9), np.float32)),
  744. Tensor(np.zeros((4, 5, 3, 9), np.float32)),
  745. Tensor(np.ones((4, 5, 3, 9), np.float32)))),
  746. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  747. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  748. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  749. }),
  750. ('TensorSetItemByMixedTensorsWithNumber_1', {
  751. 'block': TensorSetItemByMixedTensors_1(value=88.0),
  752. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  753. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  754. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  755. }),
  756. ('TensorSetItemByMixedTensorsWithTensor_1', {
  757. 'block': TensorSetItemByMixedTensors_1(value=Tensor(np.ones((5, 2, 6), np.float32))),
  758. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  759. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  760. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  761. }),
  762. ('TensorGetItemByMixedTensorsWithTupleOfNumber_1', {
  763. 'block': TensorSetItemByMixedTensors_1(value=(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)),
  764. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  765. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  766. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  767. }),
  768. ('TensorGetItemByMixedTensorsWithTupleOfTensor_1', {
  769. 'block': TensorSetItemByMixedTensors_1(value=(Tensor(np.ones((5, 2, 6), np.float32)),
  770. Tensor(np.zeros((5, 2, 6), np.float32)),
  771. Tensor(np.ones((5, 2, 6), np.float32)),
  772. Tensor(np.ones((5, 2, 6), np.float32)))),
  773. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  774. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  775. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  776. }),
  777. ('TensorSetItemByMixedTensorsWithNumber_2', {
  778. 'block': TensorSetItemByMixedTensors_2(value=88.0),
  779. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  780. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  781. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  782. }),
  783. ('TensorSetItemByMixedTensorsWithTensor_2', {
  784. 'block': TensorSetItemByMixedTensors_2(value=Tensor(np.ones((3, 4, 2, 3, 4, 5), np.float16))),
  785. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  786. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  787. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  788. }),
  789. ('TensorGetItemByMixedTensorsWithTupleOfNumber_2', {
  790. 'block': TensorSetItemByMixedTensors_2(value=(1.0, 2.0, 3.0, 4.0, 5.0)),
  791. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  792. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  793. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  794. }),
  795. ('TensorGetItemByMixedTensorsWithTupleOfTensor_2', {
  796. 'block': TensorSetItemByMixedTensors_2(value=(Tensor(np.ones((4, 5), np.float16)),
  797. Tensor(np.zeros((4, 5), np.float16)),
  798. Tensor(np.ones((4, 5), np.float16)))),
  799. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  800. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  801. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  802. }),
  803. ]
  804. raise_error_set = [
  805. ('TensorGetItemByOneTensorDtypeError', {
  806. 'block': (TensorGetItemByOneTensor(), {'exception': IndexError}),
  807. 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
  808. Tensor(np.random.randint(6, size=(5, 4)), mstype.int8)],
  809. }),
  810. ('TensorGetItemByTwoTensorsShapeError', {
  811. 'block': (TensorGetItemByTwoTensors(), {'exception': IndexError}),
  812. 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
  813. Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  814. Tensor(np.random.randint(7, size=(2, 3, 5)), mstype.int32)],
  815. }),
  816. ('TensorGetItemByTwoTensorsDtypeError', {
  817. 'block': (TensorGetItemByTwoTensors(), {'exception': IndexError}),
  818. 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
  819. Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  820. Tensor(np.random.randint(7, size=(4, 5)), mstype.float32)],
  821. }),
  822. ('TensorGetItemByThreeTensorsShapeError', {
  823. 'block': (TensorGetItemByThreeTensors(), {'exception': IndexError}),
  824. 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
  825. Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  826. Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32),
  827. Tensor(np.random.randint(8, size=(5, 2, 4, 5)), mstype.int32)],
  828. }),
  829. ('TensorGetItemByThreeTensorsDtypeError', {
  830. 'block': (TensorGetItemByThreeTensors(), {'exception': IndexError}),
  831. 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
  832. Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  833. Tensor(np.random.randint(7, size=(4, 5)), mstype.int64),
  834. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
  835. }),
  836. ('TensorGetItemByMixedTensorsNumberError', {
  837. 'block': (TensorGetItemByMixedTensorsNumberError(), {'exception': IndexError}),
  838. 'desc_inputs': [Tensor(np.arange(6 * 7 * 8).reshape((6, 7, 8)), mstype.int32),
  839. Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  840. Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32)],
  841. }),
  842. ('TensorGetItemByMixedTensorsTypeError', {
  843. 'block': (TensorGetItemByMixedTensorsTypeError(), {'exception': IndexError}),
  844. 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32),
  845. Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  846. Tensor(np.random.randint(4, size=(3, 4, 5)), mstype.int32)],
  847. }),
  848. ('TensorGetItemByMixedTensorsDtypeError', {
  849. 'block': (TensorGetItemByMixedTensors_0(), {'exception': IndexError}),
  850. 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32),
  851. Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  852. Tensor(np.random.randint(4, size=(4, 5)), mstype.float32)],
  853. }),
  854. ('TensorGetItemByMixedTensorsShapeError', {
  855. 'block': (TensorGetItemByMixedTensors_0(), {'exception': IndexError}),
  856. 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32),
  857. Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  858. Tensor(np.random.randint(4, size=(2, 4, 5)), mstype.int32)],
  859. }),
  860. ('TensorSetItemByOneTensorWithNumberTypeError', {
  861. 'block': (TensorSetItemByOneTensorWithNumber(value=0), {'exception': TypeError}),
  862. 'desc_inputs': [Tensor(np.random.randint(4, size=(5, 4)), mstype.int32)],
  863. }),
  864. ('TensorSetItemByOneTensorWithTensorShapeError', {
  865. 'block': (TensorSetItemByOneTensorWithTensor(), {'exception': ValueError}),
  866. 'desc_inputs': [Tensor(np.random.randint(3, size=(5, 4)), mstype.int32),
  867. Tensor(np.zeros((6, 7, 8)), mstype.float32)],
  868. }),
  869. ('TensorSetItemByOneTensorWithTensorDtypeError', {
  870. 'block': (TensorSetItemByOneTensorWithTensor(), {'exception': TypeError}),
  871. 'desc_inputs': [Tensor(np.random.randint(3, size=(5, 4)), mstype.int32),
  872. Tensor(np.zeros((6, 7, 8)), mstype.int32)],
  873. }),
  874. ('TensorSetItemByOneTensorWithTupleOfNumberTypeError', {
  875. 'block': (TensorSetItemByOneTensorWithTupleOfNumber(value=(0, 1, 2, 3, 4, 5, 6, 7)), {'exception': TypeError}),
  876. 'desc_inputs': [Tensor(np.random.randint(5, size=(5, 4)), mstype.int32)],
  877. }),
  878. ('TensorSetItemByOneTensorWithTupleOfNumberNumberError', {
  879. 'block': (TensorSetItemByOneTensorWithTupleOfNumber(value=(0.0, 1.1, 2.2)), {'exception': ValueError}),
  880. 'desc_inputs': [Tensor(np.random.randint(5, size=(5, 4)), mstype.int32)],
  881. }),
  882. ('TensorSetItemByOneTensorWithTupleOfTensorDtyeError', {
  883. 'block': (TensorSetItemByOneTensorWithTupleOfTensor(), {'exception': TypeError}),
  884. 'desc_inputs': [Tensor(np.random.randint(6, size=(5, 4)), mstype.int32),
  885. Tensor(np.zeros((8,), np.int32)),
  886. Tensor(np.ones((8,), np.int32)),
  887. Tensor(np.ones((8,), np.float32) * 2)],
  888. }),
  889. ('TensorSetItemByTensorsWithNumberTypeError', {
  890. 'block': (TensorSetItemByTensorsWithNumber(value=0), {'exception': TypeError}),
  891. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  892. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  893. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
  894. }),
  895. ('TensorSetItemByTensorsWithTensorShapeError', {
  896. 'block': (TensorSetItemByTensorsWithTensor(), {'exception': ValueError}),
  897. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  898. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  899. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
  900. Tensor(np.zeros((2, 5)), mstype.float32)],
  901. }),
  902. ('TensorSetItemByTensorsWithTensorTypeError', {
  903. 'block': (TensorSetItemByTensorsWithTensor(), {'exception': TypeError}),
  904. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  905. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  906. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
  907. Tensor(np.zeros((4, 5)), mstype.int32)],
  908. }),
  909. ('TensorSetItemByTensorsWithTensorNumberError', {
  910. 'block': (TensorSetItemByTensorsWithTensorNumberError(), {'exception': IndexError}),
  911. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  912. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  913. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
  914. Tensor(np.random.randint(8, size=(1, 3, 4, 5)), mstype.int32),
  915. Tensor(np.zeros((2, 5)), mstype.float32)],
  916. }),
  917. ('TensorSetItemByTensorsWithTupleOfNumberTypeError', {
  918. 'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1, 2, 3, 4)), {'exception': TypeError}),
  919. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  920. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  921. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
  922. }),
  923. ('TensorSetItemByTensorsWithTupleOfNumberNumberError', {
  924. 'block': (TensorSetItemByTensorsWithTupleOfNumber(value=(0.0, 1.0, 2.0, 3.0)), {'exception': ValueError}),
  925. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  926. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  927. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32)],
  928. }),
  929. ('TensorSetItemByTensorsWithTupleOfTensorNumberError', {
  930. 'block': (TensorSetItemByTensorsWithTupleOfTensorNumberError(), {'exception': ValueError}),
  931. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  932. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  933. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
  934. Tensor(np.zeros((4, 5)), mstype.float32),
  935. Tensor(np.ones((4, 5)), mstype.float32)],
  936. }),
  937. ('TensorSetItemByTensorsWithTupleOfTensorTypeError', {
  938. 'block': (TensorSetItemByTensorsWithTupleOfTensor(), {'exception': TypeError}),
  939. 'desc_inputs': [Tensor(np.random.randint(6, size=(3, 4, 5)), mstype.int32),
  940. Tensor(np.random.randint(7, size=(4, 5)), mstype.int32),
  941. Tensor(np.random.randint(8, size=(5, 3, 4, 5)), mstype.int32),
  942. Tensor(np.zeros((4, 5)), mstype.float32),
  943. Tensor(np.ones((4, 5)), mstype.int32),
  944. Tensor(np.ones((4, 5)) * 2, mstype.int32)],
  945. }),
  946. ('TensorSetItemByMixedTensorsWithNumberValueTypeError', {
  947. 'block': (TensorSetItemByMixedTensors_1(value=88), {'exception': TypeError}),
  948. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  949. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  950. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  951. }),
  952. ('TensorSetItemByMixedTensorsWithNumberIndexTypeError', {
  953. 'block': (TensorSetItemByMixedTensors_1(value=88.0), {'exception': IndexError}),
  954. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  955. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  956. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.float32)],
  957. }),
  958. ('TensorSetItemByMixedTensorsWithTensorValueDtypeError', {
  959. 'block': (TensorSetItemByMixedTensors_1(value=Tensor(np.ones((5, 2, 6), np.int32))),
  960. {'exception': TypeError}),
  961. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  962. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  963. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  964. }),
  965. ('TensorSetItemByMixedTensorsWithTensorValueShapeError', {
  966. 'block': (TensorSetItemByMixedTensors_1(value=Tensor(np.ones((3, 2, 6), np.float32))),
  967. {'exception': ValueError}),
  968. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  969. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  970. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  971. }),
  972. ('TensorSetItemByMixedTensorsWithTensorIndexDtypeError', {
  973. 'block': (TensorSetItemByMixedTensors_1(value=Tensor(np.ones((5, 2, 6), np.float32))),
  974. {'exception': IndexError}),
  975. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  976. Tensor(np.random.randint(4, size=(4, 5)), mstype.float32),
  977. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  978. }),
  979. ('TensorGetItemByMixedTensorsWithTupleOfNumberValueTypeError', {
  980. 'block': (TensorSetItemByMixedTensors_1(value=(1.0, 2, 3.0, 4.0, 5.0, 6.0)),
  981. {'exception': TypeError}),
  982. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  983. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  984. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  985. }),
  986. ('TensorGetItemByMixedTensorsWithTupleOfTensorValueDtypeError', {
  987. 'block': (TensorSetItemByMixedTensors_1(value=(Tensor(np.ones((5, 2, 6), np.float32)),
  988. Tensor(np.zeros((5, 2, 6), np.float32)),
  989. Tensor(np.ones((5, 2, 6), np.float32)),
  990. Tensor(np.ones((5, 2, 6), np.int32)))),
  991. {'exception': TypeError}),
  992. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32),
  993. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  994. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  995. }),
  996. ('TensorGetItemByMixedTensorsWithTupleOfTensorIndexDtypeError', {
  997. 'block': (TensorSetItemByMixedTensors_1(value=(Tensor(np.ones((5, 2, 6), np.float32)),
  998. Tensor(np.zeros((5, 2, 6), np.float32)),
  999. Tensor(np.ones((5, 2, 6), np.float32)),
  1000. Tensor(np.ones((5, 2, 6), np.int32)))),
  1001. {'exception': IndexError}),
  1002. 'desc_inputs': [Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.float32),
  1003. Tensor(np.random.randint(4, size=(4, 5)), mstype.int32),
  1004. Tensor(np.random.randint(3, size=(2, 1, 4, 5)), mstype.int32)],
  1005. })
  1006. ]
  1007. @mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
  1008. def test_exec():
  1009. context.set_context(mode=context.GRAPH_MODE)
  1010. return test_cases
  1011. @mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception)
  1012. def test_check_exception():
  1013. return raise_error_set
  1014. def test_tensor_slice_reduce_out_of_bounds_neg():
  1015. class NetWork(Cell):
  1016. def __init__(self):
  1017. super(NetWork, self).__init__()
  1018. self.tensor_ret = Tensor(np.array(9, np.int32))
  1019. def construct(self, tensor):
  1020. ret = tensor[-7, 3, 4]
  1021. return ret
  1022. input_tensor = Tensor(np.ones([6, 8, 10], np.int32))
  1023. net = NetWork()
  1024. with pytest.raises(ValueError):
  1025. net(input_tensor)
  1026. def test_tensor_slice_reduce_out_of_bounds_positive():
  1027. class NetWork(Cell):
  1028. def __init__(self):
  1029. super(NetWork, self).__init__()
  1030. self.tensor_ret = Tensor(np.array(9, np.int32))
  1031. def construct(self, tensor):
  1032. ret = tensor[6, 3, 4]
  1033. return ret
  1034. input_tensor = Tensor(np.ones([6, 8, 10], np.int32))
  1035. net = NetWork()
  1036. with pytest.raises(ValueError):
  1037. net(input_tensor)