You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_repeat_elements_op.py 21 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import pytest
  17. from mindspore import Tensor
  18. from mindspore.ops import composite as C
  19. import mindspore.nn as nn
  20. import mindspore.context as context
  21. class RepeatElementsNet(nn.Cell):
  22. def __init__(self, rep, axis):
  23. super(RepeatElementsNet, self).__init__()
  24. self.rep = rep
  25. self.axis = axis
  26. def construct(self, x):
  27. return C.repeat_elements(x, self.rep, self.axis)
  28. def repeat_elements(x, rep, axis):
  29. repeat_elements_net = RepeatElementsNet(rep, axis)
  30. return repeat_elements_net(Tensor(x.astype(np.int32))).asnumpy()
  31. @pytest.mark.level0
  32. @pytest.mark.platform_x86_gpu_training
  33. @pytest.mark.env_onecard
  34. def test_repeat_elements_1d_one_element_rep_1():
  35. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  36. a = np.arange(1)
  37. ms_out = repeat_elements(a, 1, 0)
  38. np_out = a.repeat(1, 0)
  39. np.testing.assert_array_equal(np_out, ms_out)
  40. @pytest.mark.level0
  41. @pytest.mark.platform_x86_gpu_training
  42. @pytest.mark.env_onecard
  43. def test_repeat_elements_1d_one_element_rep_many():
  44. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  45. a = np.arange(1)
  46. ms_out = repeat_elements(a, 5, 0)
  47. np_out = a.repeat(5, 0)
  48. np.testing.assert_array_equal(np_out, ms_out)
  49. ms_out = repeat_elements(a, 513, 0)
  50. np_out = a.repeat(513, 0)
  51. np.testing.assert_array_equal(np_out, ms_out)
  52. @pytest.mark.level0
  53. @pytest.mark.platform_x86_gpu_training
  54. @pytest.mark.env_onecard
  55. def test_repeat_elements_1d_rep_1():
  56. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  57. a = np.arange(24)
  58. ms_out = repeat_elements(a, 1, 0)
  59. np_out = a.repeat(1, 0)
  60. np.testing.assert_array_equal(np_out, ms_out)
  61. @pytest.mark.level0
  62. @pytest.mark.platform_x86_gpu_training
  63. @pytest.mark.env_onecard
  64. def test_repeat_elements_1d_rep_many():
  65. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  66. a = np.arange(24)
  67. ms_out = repeat_elements(a, 231, 0)
  68. np_out = a.repeat(231, 0)
  69. np.testing.assert_array_equal(np_out, ms_out)
  70. @pytest.mark.level0
  71. @pytest.mark.platform_x86_gpu_training
  72. @pytest.mark.env_onecard
  73. def test_repeat_elements_2d_one_element_rep_1():
  74. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  75. a = np.arange(1).reshape(1, 1)
  76. ms_out = repeat_elements(a, 1, 0)
  77. np_out = a.repeat(1, 0)
  78. np.testing.assert_array_equal(np_out, ms_out)
  79. ms_out = repeat_elements(a, 1, 1)
  80. np_out = a.repeat(1, 1)
  81. np.testing.assert_array_equal(np_out, ms_out)
  82. @pytest.mark.level0
  83. @pytest.mark.platform_x86_gpu_training
  84. @pytest.mark.env_onecard
  85. def test_repeat_elements_2d_one_element_rep_many():
  86. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  87. a = np.arange(1).reshape(1, 1)
  88. ms_out = repeat_elements(a, 13, 0)
  89. np_out = a.repeat(13, 0)
  90. np.testing.assert_array_equal(np_out, ms_out)
  91. ms_out = repeat_elements(a, 13, 1)
  92. np_out = a.repeat(13, 1)
  93. np.testing.assert_array_equal(np_out, ms_out)
  94. @pytest.mark.level0
  95. @pytest.mark.platform_x86_gpu_training
  96. @pytest.mark.env_onecard
  97. def test_repeat_elements_2d_rep_1():
  98. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  99. a = np.arange(24).reshape(12, 2)
  100. ms_out = repeat_elements(a, 1, 0)
  101. np_out = a.repeat(1, 0)
  102. np.testing.assert_array_equal(np_out, ms_out)
  103. ms_out = repeat_elements(a, 1, 1)
  104. np_out = a.repeat(1, 1)
  105. np.testing.assert_array_equal(np_out, ms_out)
  106. @pytest.mark.level0
  107. @pytest.mark.platform_x86_gpu_training
  108. @pytest.mark.env_onecard
  109. def test_repeat_elements_2d_rep_many():
  110. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  111. a = np.arange(24).reshape(8, 3)
  112. ms_out = repeat_elements(a, 23, 0)
  113. np_out = a.repeat(23, 0)
  114. np.testing.assert_array_equal(np_out, ms_out)
  115. ms_out = repeat_elements(a, 23, 1)
  116. np_out = a.repeat(23, 1)
  117. np.testing.assert_array_equal(np_out, ms_out)
  118. @pytest.mark.level0
  119. @pytest.mark.platform_x86_gpu_training
  120. @pytest.mark.env_onecard
  121. def test_repeat_elements_3d_one_element_rep_1():
  122. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  123. a = np.arange(1).reshape(1, 1, 1)
  124. ms_out = repeat_elements(a, 1, 0)
  125. np_out = a.repeat(1, 0)
  126. np.testing.assert_array_equal(np_out, ms_out)
  127. ms_out = repeat_elements(a, 1, 1)
  128. np_out = a.repeat(1, 1)
  129. np.testing.assert_array_equal(np_out, ms_out)
  130. ms_out = repeat_elements(a, 1, 2)
  131. np_out = a.repeat(1, 2)
  132. np.testing.assert_array_equal(np_out, ms_out)
  133. @pytest.mark.level0
  134. @pytest.mark.platform_x86_gpu_training
  135. @pytest.mark.env_onecard
  136. def test_repeat_elements_3d_one_element_rep_many():
  137. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  138. a = np.arange(1).reshape(1, 1, 1)
  139. ms_out = repeat_elements(a, 43, 0)
  140. np_out = a.repeat(43, 0)
  141. np.testing.assert_array_equal(np_out, ms_out)
  142. ms_out = repeat_elements(a, 43, 1)
  143. np_out = a.repeat(43, 1)
  144. np.testing.assert_array_equal(np_out, ms_out)
  145. ms_out = repeat_elements(a, 43, 2)
  146. np_out = a.repeat(43, 2)
  147. np.testing.assert_array_equal(np_out, ms_out)
  148. @pytest.mark.level0
  149. @pytest.mark.platform_x86_gpu_training
  150. @pytest.mark.env_onecard
  151. def test_repeat_elements_3d_rep_1():
  152. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  153. a = np.arange(60).reshape(6, 2, 5)
  154. ms_out = repeat_elements(a, 1, 0)
  155. np_out = a.repeat(1, 0)
  156. np.testing.assert_array_equal(np_out, ms_out)
  157. ms_out = repeat_elements(a, 1, 1)
  158. np_out = a.repeat(1, 1)
  159. np.testing.assert_array_equal(np_out, ms_out)
  160. ms_out = repeat_elements(a, 1, 2)
  161. np_out = a.repeat(1, 2)
  162. np.testing.assert_array_equal(np_out, ms_out)
  163. @pytest.mark.level0
  164. @pytest.mark.platform_x86_gpu_training
  165. @pytest.mark.env_onecard
  166. def test_repeat_elements_3d_rep_many():
  167. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  168. a = np.arange(60).reshape(3, 4, 5)
  169. ms_out = repeat_elements(a, 14, 0)
  170. np_out = a.repeat(14, 0)
  171. np.testing.assert_array_equal(np_out, ms_out)
  172. ms_out = repeat_elements(a, 14, 1)
  173. np_out = a.repeat(14, 1)
  174. np.testing.assert_array_equal(np_out, ms_out)
  175. ms_out = repeat_elements(a, 14, 2)
  176. np_out = a.repeat(14, 2)
  177. np.testing.assert_array_equal(np_out, ms_out)
  178. @pytest.mark.level0
  179. @pytest.mark.platform_x86_gpu_training
  180. @pytest.mark.env_onecard
  181. def test_repeat_elements_4d_one_element_rep_1():
  182. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  183. a = np.arange(1).reshape(1, 1, 1, 1)
  184. ms_out = repeat_elements(a, 1, 0)
  185. np_out = a.repeat(1, 0)
  186. np.testing.assert_array_equal(np_out, ms_out)
  187. ms_out = repeat_elements(a, 1, 1)
  188. np_out = a.repeat(1, 1)
  189. np.testing.assert_array_equal(np_out, ms_out)
  190. ms_out = repeat_elements(a, 1, 2)
  191. np_out = a.repeat(1, 2)
  192. np.testing.assert_array_equal(np_out, ms_out)
  193. ms_out = repeat_elements(a, 1, 3)
  194. np_out = a.repeat(1, 3)
  195. np.testing.assert_array_equal(np_out, ms_out)
  196. @pytest.mark.level0
  197. @pytest.mark.platform_x86_gpu_training
  198. @pytest.mark.env_onecard
  199. def test_repeat_elements_4d_one_element_rep_many():
  200. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  201. a = np.arange(1).reshape(1, 1, 1, 1)
  202. ms_out = repeat_elements(a, 17, 0)
  203. np_out = a.repeat(17, 0)
  204. np.testing.assert_array_equal(np_out, ms_out)
  205. ms_out = repeat_elements(a, 17, 1)
  206. np_out = a.repeat(17, 1)
  207. np.testing.assert_array_equal(np_out, ms_out)
  208. ms_out = repeat_elements(a, 17, 2)
  209. np_out = a.repeat(17, 2)
  210. np.testing.assert_array_equal(np_out, ms_out)
  211. ms_out = repeat_elements(a, 17, 3)
  212. np_out = a.repeat(17, 3)
  213. np.testing.assert_array_equal(np_out, ms_out)
  214. @pytest.mark.level0
  215. @pytest.mark.platform_x86_gpu_training
  216. @pytest.mark.env_onecard
  217. def test_repeat_elements_4d_rep_1():
  218. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  219. a = np.arange(24).reshape(4, 3, 2, 1)
  220. ms_out = repeat_elements(a, 1, 0)
  221. np_out = a.repeat(1, 0)
  222. np.testing.assert_array_equal(np_out, ms_out)
  223. ms_out = repeat_elements(a, 1, 1)
  224. np_out = a.repeat(1, 1)
  225. np.testing.assert_array_equal(np_out, ms_out)
  226. ms_out = repeat_elements(a, 1, 2)
  227. np_out = a.repeat(1, 2)
  228. np.testing.assert_array_equal(np_out, ms_out)
  229. ms_out = repeat_elements(a, 1, 3)
  230. np_out = a.repeat(1, 3)
  231. np.testing.assert_array_equal(np_out, ms_out)
  232. @pytest.mark.level0
  233. @pytest.mark.platform_x86_gpu_training
  234. @pytest.mark.env_onecard
  235. def test_repeat_elements_4d_rep_many():
  236. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  237. a = np.arange(24).reshape(2, 2, 2, 3)
  238. ms_out = repeat_elements(a, 23, 0)
  239. np_out = a.repeat(23, 0)
  240. np.testing.assert_array_equal(np_out, ms_out)
  241. ms_out = repeat_elements(a, 23, 1)
  242. np_out = a.repeat(23, 1)
  243. np.testing.assert_array_equal(np_out, ms_out)
  244. ms_out = repeat_elements(a, 23, 2)
  245. np_out = a.repeat(23, 2)
  246. np.testing.assert_array_equal(np_out, ms_out)
  247. ms_out = repeat_elements(a, 23, 3)
  248. np_out = a.repeat(23, 3)
  249. np.testing.assert_array_equal(np_out, ms_out)
  250. @pytest.mark.level0
  251. @pytest.mark.platform_x86_gpu_training
  252. @pytest.mark.env_onecard
  253. def test_repeat_elements_5d_one_element_rep_1():
  254. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  255. a = np.arange(1).reshape(1, 1, 1, 1, 1)
  256. ms_out = repeat_elements(a, 1, 0)
  257. np_out = a.repeat(1, 0)
  258. np.testing.assert_array_equal(np_out, ms_out)
  259. ms_out = repeat_elements(a, 1, 1)
  260. np_out = a.repeat(1, 1)
  261. np.testing.assert_array_equal(np_out, ms_out)
  262. ms_out = repeat_elements(a, 1, 2)
  263. np_out = a.repeat(1, 2)
  264. np.testing.assert_array_equal(np_out, ms_out)
  265. ms_out = repeat_elements(a, 1, 3)
  266. np_out = a.repeat(1, 3)
  267. np.testing.assert_array_equal(np_out, ms_out)
  268. ms_out = repeat_elements(a, 1, 4)
  269. np_out = a.repeat(1, 4)
  270. np.testing.assert_array_equal(np_out, ms_out)
  271. @pytest.mark.level0
  272. @pytest.mark.platform_x86_gpu_training
  273. @pytest.mark.env_onecard
  274. def test_repeat_elements_5d_one_element_rep_many():
  275. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  276. a = np.arange(1).reshape(1, 1, 1, 1, 1)
  277. ms_out = repeat_elements(a, 19, 0)
  278. np_out = a.repeat(19, 0)
  279. np.testing.assert_array_equal(np_out, ms_out)
  280. ms_out = repeat_elements(a, 19, 1)
  281. np_out = a.repeat(19, 1)
  282. np.testing.assert_array_equal(np_out, ms_out)
  283. ms_out = repeat_elements(a, 19, 2)
  284. np_out = a.repeat(19, 2)
  285. np.testing.assert_array_equal(np_out, ms_out)
  286. ms_out = repeat_elements(a, 19, 3)
  287. np_out = a.repeat(19, 3)
  288. np.testing.assert_array_equal(np_out, ms_out)
  289. ms_out = repeat_elements(a, 19, 4)
  290. np_out = a.repeat(19, 4)
  291. np.testing.assert_array_equal(np_out, ms_out)
  292. @pytest.mark.level0
  293. @pytest.mark.platform_x86_gpu_training
  294. @pytest.mark.env_onecard
  295. def test_repeat_elements_5d_rep_1():
  296. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  297. a = np.arange(224).reshape(8, 2, 1, 7, 2)
  298. ms_out = repeat_elements(a, 1, 0)
  299. np_out = a.repeat(1, 0)
  300. np.testing.assert_array_equal(np_out, ms_out)
  301. ms_out = repeat_elements(a, 1, 1)
  302. np_out = a.repeat(1, 1)
  303. np.testing.assert_array_equal(np_out, ms_out)
  304. ms_out = repeat_elements(a, 1, 2)
  305. np_out = a.repeat(1, 2)
  306. np.testing.assert_array_equal(np_out, ms_out)
  307. ms_out = repeat_elements(a, 1, 3)
  308. np_out = a.repeat(1, 3)
  309. np.testing.assert_array_equal(np_out, ms_out)
  310. ms_out = repeat_elements(a, 1, 4)
  311. np_out = a.repeat(1, 4)
  312. np.testing.assert_array_equal(np_out, ms_out)
  313. @pytest.mark.level0
  314. @pytest.mark.platform_x86_gpu_training
  315. @pytest.mark.env_onecard
  316. def test_repeat_elements_5d_rep_many():
  317. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  318. a = np.arange(224).reshape(1, 7, 4, 4, 2)
  319. ms_out = repeat_elements(a, 7, 0)
  320. np_out = a.repeat(7, 0)
  321. np.testing.assert_array_equal(np_out, ms_out)
  322. ms_out = repeat_elements(a, 7, 1)
  323. np_out = a.repeat(7, 1)
  324. np.testing.assert_array_equal(np_out, ms_out)
  325. ms_out = repeat_elements(a, 7, 2)
  326. np_out = a.repeat(7, 2)
  327. np.testing.assert_array_equal(np_out, ms_out)
  328. ms_out = repeat_elements(a, 7, 3)
  329. np_out = a.repeat(7, 3)
  330. np.testing.assert_array_equal(np_out, ms_out)
  331. ms_out = repeat_elements(a, 7, 4)
  332. np_out = a.repeat(7, 4)
  333. np.testing.assert_array_equal(np_out, ms_out)
  334. @pytest.mark.level0
  335. @pytest.mark.platform_x86_gpu_training
  336. @pytest.mark.env_onecard
  337. def test_repeat_elements_large_one_element_rep_1():
  338. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  339. a = np.arange(1).reshape(1, 1, 1, 1, 1, 1, 1, 1)
  340. ms_out = repeat_elements(a, 1, 0)
  341. np_out = a.repeat(1, 0)
  342. np.testing.assert_array_equal(np_out, ms_out)
  343. ms_out = repeat_elements(a, 1, 1)
  344. np_out = a.repeat(1, 1)
  345. np.testing.assert_array_equal(np_out, ms_out)
  346. ms_out = repeat_elements(a, 1, 2)
  347. np_out = a.repeat(1, 2)
  348. np.testing.assert_array_equal(np_out, ms_out)
  349. ms_out = repeat_elements(a, 1, 3)
  350. np_out = a.repeat(1, 3)
  351. np.testing.assert_array_equal(np_out, ms_out)
  352. ms_out = repeat_elements(a, 1, 4)
  353. np_out = a.repeat(1, 4)
  354. np.testing.assert_array_equal(np_out, ms_out)
  355. ms_out = repeat_elements(a, 1, 5)
  356. np_out = a.repeat(1, 5)
  357. np.testing.assert_array_equal(np_out, ms_out)
  358. ms_out = repeat_elements(a, 1, 6)
  359. np_out = a.repeat(1, 6)
  360. np.testing.assert_array_equal(np_out, ms_out)
  361. ms_out = repeat_elements(a, 1, 7)
  362. np_out = a.repeat(1, 7)
  363. np.testing.assert_array_equal(np_out, ms_out)
  364. @pytest.mark.level0
  365. @pytest.mark.platform_x86_gpu_training
  366. @pytest.mark.env_onecard
  367. def test_repeat_elements_large_one_element_rep_many():
  368. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  369. a = np.arange(1).reshape(1, 1, 1, 1, 1, 1, 1, 1)
  370. ms_out = repeat_elements(a, 42, 0)
  371. np_out = a.repeat(42, 0)
  372. np.testing.assert_array_equal(np_out, ms_out)
  373. ms_out = repeat_elements(a, 42, 1)
  374. np_out = a.repeat(42, 1)
  375. np.testing.assert_array_equal(np_out, ms_out)
  376. ms_out = repeat_elements(a, 42, 2)
  377. np_out = a.repeat(42, 2)
  378. np.testing.assert_array_equal(np_out, ms_out)
  379. ms_out = repeat_elements(a, 42, 3)
  380. np_out = a.repeat(42, 3)
  381. np.testing.assert_array_equal(np_out, ms_out)
  382. ms_out = repeat_elements(a, 42, 4)
  383. np_out = a.repeat(42, 4)
  384. np.testing.assert_array_equal(np_out, ms_out)
  385. ms_out = repeat_elements(a, 42, 5)
  386. np_out = a.repeat(42, 5)
  387. np.testing.assert_array_equal(np_out, ms_out)
  388. ms_out = repeat_elements(a, 42, 6)
  389. np_out = a.repeat(42, 6)
  390. np.testing.assert_array_equal(np_out, ms_out)
  391. ms_out = repeat_elements(a, 42, 7)
  392. np_out = a.repeat(42, 7)
  393. np.testing.assert_array_equal(np_out, ms_out)
  394. @pytest.mark.level0
  395. @pytest.mark.platform_x86_gpu_training
  396. @pytest.mark.env_onecard
  397. def test_repeat_elements_large_rep_1():
  398. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  399. a = np.arange(1152).reshape(2, 3, 4, 8, 1, 1, 2, 3)
  400. ms_out = repeat_elements(a, 1, 0)
  401. np_out = a.repeat(1, 0)
  402. np.testing.assert_array_equal(np_out, ms_out)
  403. ms_out = repeat_elements(a, 1, 1)
  404. np_out = a.repeat(1, 1)
  405. np.testing.assert_array_equal(np_out, ms_out)
  406. ms_out = repeat_elements(a, 1, 2)
  407. np_out = a.repeat(1, 2)
  408. np.testing.assert_array_equal(np_out, ms_out)
  409. ms_out = repeat_elements(a, 1, 3)
  410. np_out = a.repeat(1, 3)
  411. np.testing.assert_array_equal(np_out, ms_out)
  412. ms_out = repeat_elements(a, 1, 4)
  413. np_out = a.repeat(1, 4)
  414. np.testing.assert_array_equal(np_out, ms_out)
  415. ms_out = repeat_elements(a, 1, 5)
  416. np_out = a.repeat(1, 5)
  417. np.testing.assert_array_equal(np_out, ms_out)
  418. ms_out = repeat_elements(a, 1, 6)
  419. np_out = a.repeat(1, 6)
  420. np.testing.assert_array_equal(np_out, ms_out)
  421. ms_out = repeat_elements(a, 1, 7)
  422. np_out = a.repeat(1, 7)
  423. np.testing.assert_array_equal(np_out, ms_out)
  424. @pytest.mark.level0
  425. @pytest.mark.platform_x86_gpu_training
  426. @pytest.mark.env_onecard
  427. def test_repeat_elements_large_rep_many():
  428. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  429. a = np.arange(1152).reshape(4, 3, 4, 2, 1, 1, 4, 3)
  430. ms_out = repeat_elements(a, 4, 0)
  431. np_out = a.repeat(4, 0)
  432. np.testing.assert_array_equal(np_out, ms_out)
  433. ms_out = repeat_elements(a, 4, 1)
  434. np_out = a.repeat(4, 1)
  435. np.testing.assert_array_equal(np_out, ms_out)
  436. ms_out = repeat_elements(a, 4, 2)
  437. np_out = a.repeat(4, 2)
  438. np.testing.assert_array_equal(np_out, ms_out)
  439. ms_out = repeat_elements(a, 4, 3)
  440. np_out = a.repeat(4, 3)
  441. np.testing.assert_array_equal(np_out, ms_out)
  442. ms_out = repeat_elements(a, 4, 4)
  443. np_out = a.repeat(4, 4)
  444. np.testing.assert_array_equal(np_out, ms_out)
  445. ms_out = repeat_elements(a, 4, 5)
  446. np_out = a.repeat(4, 5)
  447. np.testing.assert_array_equal(np_out, ms_out)
  448. ms_out = repeat_elements(a, 4, 6)
  449. np_out = a.repeat(4, 6)
  450. np.testing.assert_array_equal(np_out, ms_out)
  451. ms_out = repeat_elements(a, 4, 7)
  452. np_out = a.repeat(4, 7)
  453. np.testing.assert_array_equal(np_out, ms_out)
  454. @pytest.mark.level0
  455. @pytest.mark.platform_x86_gpu_training
  456. @pytest.mark.env_onecard
  457. def test_repeat_elements_half():
  458. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  459. a = np.arange(1152).astype(np.float16).reshape(4, 3, 4, 2, 1, 1, 4, 3)
  460. ms_out = repeat_elements(a, 4, 0)
  461. np_out = a.repeat(4, 0)
  462. np.testing.assert_array_equal(np_out, ms_out)
  463. ms_out = repeat_elements(a, 4, 1)
  464. np_out = a.repeat(4, 1)
  465. np.testing.assert_array_equal(np_out, ms_out)
  466. ms_out = repeat_elements(a, 4, 2)
  467. np_out = a.repeat(4, 2)
  468. np.testing.assert_array_equal(np_out, ms_out)
  469. ms_out = repeat_elements(a, 4, 3)
  470. np_out = a.repeat(4, 3)
  471. np.testing.assert_array_equal(np_out, ms_out)
  472. ms_out = repeat_elements(a, 4, 4)
  473. np_out = a.repeat(4, 4)
  474. np.testing.assert_array_equal(np_out, ms_out)
  475. ms_out = repeat_elements(a, 4, 5)
  476. np_out = a.repeat(4, 5)
  477. np.testing.assert_array_equal(np_out, ms_out)
  478. ms_out = repeat_elements(a, 4, 6)
  479. np_out = a.repeat(4, 6)
  480. np.testing.assert_array_equal(np_out, ms_out)
  481. ms_out = repeat_elements(a, 4, 7)
  482. np_out = a.repeat(4, 7)
  483. np.testing.assert_array_equal(np_out, ms_out)
  484. @pytest.mark.level0
  485. @pytest.mark.platform_x86_gpu_training
  486. @pytest.mark.env_onecard
  487. def test_repeat_elements_net_multi_use():
  488. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  489. rep = 3
  490. axis = 4
  491. repeat_elements_net = RepeatElementsNet(rep, axis)
  492. a = np.arange(64).reshape(2, 2, 2, 2, 2, 2)
  493. ms_out = repeat_elements_net(Tensor(a.astype(np.int32))).asnumpy()
  494. np_out = a.repeat(rep, axis)
  495. np.testing.assert_array_equal(np_out, ms_out)
  496. a = np.arange(128).reshape(2, 2, 4, 2, 2, 2)
  497. ms_out = repeat_elements_net(Tensor(a.astype(np.int32))).asnumpy()
  498. np_out = a.repeat(rep, axis)
  499. np.testing.assert_array_equal(np_out, ms_out)
  500. a = np.arange(18).reshape(1, 1, 3, 2, 3, 1)
  501. ms_out = repeat_elements_net(Tensor(a.astype(np.int32))).asnumpy()
  502. np_out = a.repeat(rep, axis)
  503. np.testing.assert_array_equal(np_out, ms_out)
  504. @pytest.mark.level0
  505. @pytest.mark.platform_x86_gpu_training
  506. @pytest.mark.env_onecard
  507. def test_repeat_elements_invalid_input():
  508. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  509. a = np.arange(64).reshape(2, 2, 2, 2, 2, 2)
  510. with pytest.raises(ValueError):
  511. _ = repeat_elements(a, 0, 0)
  512. with pytest.raises(ValueError):
  513. _ = repeat_elements(a, 1, 6)
  514. with pytest.raises(ValueError):
  515. _ = repeat_elements(a, 1, -7)