You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_neighborexchangev2.py 30 kB

4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import pytest
  16. import numpy as np
  17. import mindspore as ms
  18. import mindspore.context as context
  19. from mindspore import Tensor
  20. import mindspore.nn as nn
  21. from mindspore.common.api import _cell_graph_executor
  22. from mindspore.nn import TrainOneStepCell, Momentum
  23. from mindspore.ops.operations.comm_ops import NeighborExchangeV2
  # Shared float32 inputs, laid out as (N, C, H, W) because the tests use data_format="NCHW".
  # _x2 is one row taller in H (33 vs 32); the success tests add it to the exchange output
  # when recv_lens asks for one extra received row.
  _x1 = Tensor(np.ones((1, 1, 32, 16)), dtype=ms.float32)
  _x2 = Tensor(np.ones((1, 1, 33, 16)), dtype=ms.float32)
  def compile_net(net, x1, x2):
      """Wrap ``net`` in a one-step Momentum training cell and compile it in graph mode.

      The graph is only compiled, never executed.

      Args:
          net: the network under test; its trainable parameters are handed to Momentum.
          x1: first input tensor used to compile the training graph.
          x2: second input tensor used to compile the training graph.
      """
      context.set_context(mode=context.GRAPH_MODE)
      trainer = TrainOneStepCell(
          net, Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9))
      trainer.set_train()
      _cell_graph_executor.compile(trainer, x1, x2)
  def test_neighborexchangev2_single_input_success():
      """
      Feature: NeighborExchangeV2
      Description: a single input tensor and a single output tensor, with valid arguments
      Expectation: success
      """

      class Net(nn.Cell):
          def __init__(self):
              super().__init__()
              self.linear = nn.Dense(16, 16)
              self.neighborexchangev2 = NeighborExchangeV2(
                  send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  send_lens=[0, 1, 0, 0],
                  recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  recv_lens=[0, 1, 0, 0],
                  data_format="NCHW",
              )

          def construct(self, x1, x2):
              exchanged = self.neighborexchangev2(self.linear(x1))
              return exchanged + x2

      context.set_auto_parallel_context(device_num=8, global_rank=0)
      compile_net(Net(), _x1, _x2)
  def test_neighborexchangev2_send_lens_equal_to_input_shape_success():
      """
      Feature: NeighborExchangeV2
      Description: send_lens[1] (32) equals the H size of the 32-row input, with valid arguments
      Expectation: success
      """

      class Net(nn.Cell):
          def __init__(self):
              super().__init__()
              self.linear = nn.Dense(16, 16)
              self.neighborexchangev2 = NeighborExchangeV2(
                  send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  send_lens=[0, 32, 0, 0],
                  recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  recv_lens=[0, 1, 0, 0],
                  data_format="NCHW",
              )

          def construct(self, x1, x2):
              exchanged = self.neighborexchangev2(self.linear(x1))
              return exchanged + x2

      context.set_auto_parallel_context(device_num=8, global_rank=0)
      compile_net(Net(), _x1, _x2)
  def test_neighborexchangev2_empty_send_success():
      """
      Feature: NeighborExchangeV2
      Description: nothing is sent (every send_rank_ids entry is -1) while one row is received
                   from rank 1, with valid arguments
      Expectation: success
      """

      class Net(nn.Cell):
          def __init__(self):
              super().__init__()
              self.linear = nn.Dense(16, 16)
              self.neighborexchangev2 = NeighborExchangeV2(
                  send_rank_ids=[-1, -1, -1, -1, -1, -1, -1, -1],
                  send_lens=[1, 2, 3, 4],
                  recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  recv_lens=[0, 1, 0, 0],
                  data_format="NCHW",
              )

          def construct(self, x1, x2):
              exchanged = self.neighborexchangev2(self.linear(x1))
              return exchanged + x2

      context.set_auto_parallel_context(device_num=8, global_rank=0)
      compile_net(Net(), _x1, _x2)
  def test_neighborexchangev2_empty_recv_success():
      """
      Feature: NeighborExchangeV2
      Description: nothing is received (every recv_rank_ids entry is -1) while one row is sent
                   to rank 1, with valid arguments
      Expectation: success
      """

      class Net(nn.Cell):
          def __init__(self):
              super().__init__()
              self.linear = nn.Dense(16, 16)
              self.neighborexchangev2 = NeighborExchangeV2(
                  send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  send_lens=[0, 1, 0, 0],
                  recv_rank_ids=[-1, -1, -1, -1, -1, -1, -1, -1],
                  recv_lens=[1, 2, 3, 4],
                  data_format="NCHW",
              )

          def construct(self, x1, x2):
              exchanged = self.neighborexchangev2(self.linear(x1))
              return exchanged + x2

      context.set_auto_parallel_context(device_num=8, global_rank=0)
      # Nothing is received here, so the second input is _x1 rather than the taller _x2.
      compile_net(Net(), _x1, _x1)
  def test_neighborexchangev2_empty_send_empty_recv_success():
      """
      Feature: NeighborExchangeV2
      Description: every send_rank_ids and recv_rank_ids entry is -1, with valid arguments
      Expectation: success
      """

      class Net(nn.Cell):
          def __init__(self):
              super().__init__()
              self.neighborexchangev2 = NeighborExchangeV2(
                  send_rank_ids=[-1, -1, -1, -1, -1, -1, -1, -1],
                  send_lens=[0, 1, 0, 0],
                  recv_rank_ids=[-1, -1, -1, -1, -1, -1, -1, -1],
                  recv_lens=[1, 2, 3, 4],
                  data_format="NCHW",
              )

          def construct(self, x1):
              return self.neighborexchangev2(x1)

      context.set_auto_parallel_context(device_num=8, global_rank=0)
      _cell_graph_executor.compile(Net(), _x1)
  def test_neighborexchangev2_invalid_dataformat_failed():
      """
      Feature: NeighborExchangeV2
      Description: data_format must be NCHW, but NHWC is given
      Expectation: throw ValueError
      """

      class Net(nn.Cell):
          def __init__(self):
              super().__init__()
              self.neighborexchangev2 = NeighborExchangeV2(
                  send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  send_lens=[0, 1, 0, 0],
                  recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  recv_lens=[0, 1, 0, 0],
                  data_format="NHWC",
              )

          def construct(self, x):
              return self.neighborexchangev2(x)[0]

      context.set_auto_parallel_context(device_num=8, global_rank=0)
      net = Net()
      with pytest.raises(ValueError):
          _cell_graph_executor.compile(net, _x1)
  def test_neighborexchangev2_invalid_send_rank_ids_size_failed():
      """
      Feature: NeighborExchangeV2
      Description: send_rank_ids must have 8 entries, but 5 are given
      Expectation: throw ValueError
      """

      class Net(nn.Cell):
          def __init__(self):
              super().__init__()
              self.neighborexchangev2 = NeighborExchangeV2(
                  send_rank_ids=[-1, -1, -1, -1, 1],
                  send_lens=[0, 1, 0, 0],
                  recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  recv_lens=[0, 1, 0, 0],
                  data_format="NCHW",
              )

          def construct(self, x):
              return self.neighborexchangev2(x)[0]

      context.set_auto_parallel_context(device_num=8, global_rank=0)
      net = Net()
      with pytest.raises(ValueError):
          _cell_graph_executor.compile(net, _x1)
  def test_neighborexchangev2_invalid_recv_rank_ids_size_failed():
      """
      Feature: NeighborExchangeV2
      Description: recv_rank_ids must have 8 entries, but 5 are given
      Expectation: throw ValueError
      """

      class Net(nn.Cell):
          def __init__(self):
              super().__init__()
              self.neighborexchangev2 = NeighborExchangeV2(
                  send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  send_lens=[0, 1, 0, 0],
                  recv_rank_ids=[-1, -1, -1, -1, 1],
                  recv_lens=[0, 1, 0, 0],
                  data_format="NCHW",
              )

          def construct(self, x):
              return self.neighborexchangev2(x)[0]

      context.set_auto_parallel_context(device_num=8, global_rank=0)
      net = Net()
      with pytest.raises(ValueError):
          _cell_graph_executor.compile(net, _x1)
  def test_neighborexchangev2_invalid_send_lens_size_failed():
      """
      Feature: NeighborExchangeV2
      Description: send_lens must have 4 entries, but 5 are given
      Expectation: throw ValueError
      """

      class Net(nn.Cell):
          def __init__(self):
              super().__init__()
              self.neighborexchangev2 = NeighborExchangeV2(
                  send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  send_lens=[0, 1, 0, 0, 2],
                  recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  recv_lens=[0, 1, 0, 0],
                  data_format="NCHW",
              )

          def construct(self, x):
              return self.neighborexchangev2(x)[0]

      context.set_auto_parallel_context(device_num=8, global_rank=0)
      net = Net()
      with pytest.raises(ValueError):
          _cell_graph_executor.compile(net, _x1)
  def test_neighborexchangev2_invalid_recv_lens_size_failed():
      """
      Feature: NeighborExchangeV2
      Description: recv_lens must have 4 entries, but 5 are given
      Expectation: throw ValueError
      """

      class Net(nn.Cell):
          def __init__(self):
              super().__init__()
              self.neighborexchangev2 = NeighborExchangeV2(
                  send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  send_lens=[0, 1, 0, 0],
                  recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  recv_lens=[0, 1, 0, 0, 2],
                  data_format="NCHW",
              )

          def construct(self, x):
              return self.neighborexchangev2(x)[0]

      context.set_auto_parallel_context(device_num=8, global_rank=0)
      net = Net()
      with pytest.raises(ValueError):
          _cell_graph_executor.compile(net, _x1)
  def test_neighborexchangev2_invalid_input_size_failed():
      """
      Feature: NeighborExchangeV2
      Description: the operator takes a single input tensor, but two are given
      Expectation: throw ValueError
      """

      class Net(nn.Cell):
          def __init__(self):
              super().__init__()
              self.neighborexchangev2 = NeighborExchangeV2(
                  send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  send_lens=[0, 1, 0, 0],
                  recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  recv_lens=[0, 1, 0, 0],
                  data_format="NCHW",
              )

          def construct(self, x1, x2):
              return self.neighborexchangev2(x1, x2)[0]

      context.set_auto_parallel_context(device_num=8, global_rank=0)
      net = Net()
      with pytest.raises(ValueError):
          _cell_graph_executor.compile(net, _x1, _x2)
  def test_neighborexchangev2_recv_rank_ids_invalid_value_failed():
      """
      Feature: NeighborExchangeV2
      Description: the received pieces must be concatenable, but recv_rank_ids[3] and
                   recv_rank_ids[4] are 1 while recv_rank_ids[5] is -1
      Expectation: throw ValueError
      """

      class Net(nn.Cell):
          def __init__(self):
              super().__init__()
              self.neighborexchangev2 = NeighborExchangeV2(
                  send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  send_lens=[0, 1, 0, 0],
                  recv_rank_ids=[-1, -1, -1, 1, 1, -1, -1, -1],
                  recv_lens=[0, 1, 0, 0],
                  data_format="NCHW",
              )

          def construct(self, x):
              return self.neighborexchangev2(x)[0]

      context.set_auto_parallel_context(device_num=8, global_rank=0)
      net = Net()
      with pytest.raises(ValueError):
          _cell_graph_executor.compile(net, _x1)
  def test_neighborexchangev2_attr_check_send_rank_ids_is_tuple_failed():
      """
      Feature: NeighborExchangeV2
      Description: send_rank_ids must be a list, but a tuple is given
      Expectation: throw TypeError
      """

      class Net(nn.Cell):
          def __init__(self):
              super().__init__()
              self.neighborexchangev2 = NeighborExchangeV2(
                  send_rank_ids=(-1, -1, -1, -1, 1, -1, -1, -1),
                  send_lens=[0, 1, 0, 0],
                  recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  recv_lens=[0, 1, 0, 0],
                  data_format="NCHW",
              )

          def construct(self, x):
              return self.neighborexchangev2(x)[0]

      context.set_auto_parallel_context(device_num=8, global_rank=0)
      net = Net()
      with pytest.raises(TypeError):
          _cell_graph_executor.compile(net, _x1)
  def test_neighborexchangev2_attr_check_send_lens_is_tuple_failed():
      """
      Feature: NeighborExchangeV2
      Description: send_lens must be a list, but a tuple is given
      Expectation: throw TypeError
      """

      class Net(nn.Cell):
          def __init__(self):
              super().__init__()
              self.neighborexchangev2 = NeighborExchangeV2(
                  send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  send_lens=(0, 1, 0, 0),
                  recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  recv_lens=[0, 1, 0, 0],
                  data_format="NCHW",
              )

          def construct(self, x):
              return self.neighborexchangev2(x)[0]

      context.set_auto_parallel_context(device_num=8, global_rank=0)
      net = Net()
      with pytest.raises(TypeError):
          _cell_graph_executor.compile(net, _x1)
  def test_neighborexchangev2_attr_check_recv_rank_ids_is_tuple_failed():
      """
      Feature: NeighborExchangeV2
      Description: recv_rank_ids must be a list, but a tuple is given
      Expectation: throw TypeError
      """

      class Net(nn.Cell):
          def __init__(self):
              super().__init__()
              self.neighborexchangev2 = NeighborExchangeV2(
                  send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  send_lens=[0, 1, 0, 0],
                  recv_rank_ids=(-1, -1, -1, -1, 1, -1, -1, -1),
                  recv_lens=[0, 1, 0, 0],
                  data_format="NCHW",
              )

          def construct(self, x):
              return self.neighborexchangev2(x)[0]

      context.set_auto_parallel_context(device_num=8, global_rank=0)
      net = Net()
      with pytest.raises(TypeError):
          _cell_graph_executor.compile(net, _x1)
  def test_neighborexchangev2_attr_check_recv_lens_is_tuple_failed():
      """
      Feature: NeighborExchangeV2
      Description: recv_lens must be a list, but a tuple is given
      Expectation: throw TypeError
      """

      class Net(nn.Cell):
          def __init__(self):
              super().__init__()
              self.neighborexchangev2 = NeighborExchangeV2(
                  send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  send_lens=[0, 1, 0, 0],
                  recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  recv_lens=(0, 1, 0, 0),
                  data_format="NCHW",
              )

          def construct(self, x):
              return self.neighborexchangev2(x)[0]

      context.set_auto_parallel_context(device_num=8, global_rank=0)
      net = Net()
      with pytest.raises(TypeError):
          _cell_graph_executor.compile(net, _x1)
  def test_neighborexchangev2_attr_check_send_rank_ids_is_float_failed():
      """
      Feature: NeighborExchangeV2
      Description: send_rank_ids entries must be int, but a float (1.0) is given
      Expectation: throw TypeError
      """

      class Net(nn.Cell):
          def __init__(self):
              super().__init__()
              self.neighborexchangev2 = NeighborExchangeV2(
                  send_rank_ids=[-1, -1, -1, -1, 1.0, -1, -1, -1],
                  send_lens=[0, 1, 0, 0],
                  recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  recv_lens=[0, 1, 0, 0],
                  data_format="NCHW",
              )

          def construct(self, x):
              return self.neighborexchangev2(x)[0]

      context.set_auto_parallel_context(device_num=8, global_rank=0)
      net = Net()
      with pytest.raises(TypeError):
          _cell_graph_executor.compile(net, _x1)
  def test_neighborexchangev2_attr_check_send_lens_is_float_failed():
      """
      Feature: NeighborExchangeV2
      Description: send_lens entries must be int, but a float (1.0) is given
      Expectation: throw TypeError
      """

      class Net(nn.Cell):
          def __init__(self):
              super().__init__()
              self.neighborexchangev2 = NeighborExchangeV2(
                  send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  send_lens=[0, 1.0, 0, 0],
                  recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  recv_lens=[0, 1, 0, 0],
                  data_format="NCHW",
              )

          def construct(self, x):
              return self.neighborexchangev2(x)[0]

      context.set_auto_parallel_context(device_num=8, global_rank=0)
      net = Net()
      with pytest.raises(TypeError):
          _cell_graph_executor.compile(net, _x1)
  def test_neighborexchangev2_attr_check_recv_rank_ids_is_float_failed():
      """
      Feature: NeighborExchangeV2
      Description: recv_rank_ids entries must be int, but a float (1.0) is given
      Expectation: throw TypeError
      """

      class Net(nn.Cell):
          def __init__(self):
              super().__init__()
              self.neighborexchangev2 = NeighborExchangeV2(
                  send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  send_lens=[0, 1, 0, 0],
                  recv_rank_ids=[-1, -1, -1, -1, 1.0, -1, -1, -1],
                  recv_lens=[0, 1, 0, 0],
                  data_format="NCHW",
              )

          def construct(self, x):
              return self.neighborexchangev2(x)[0]

      context.set_auto_parallel_context(device_num=8, global_rank=0)
      net = Net()
      with pytest.raises(TypeError):
          _cell_graph_executor.compile(net, _x1)
  def test_neighborexchangev2_attr_check_recv_lens_is_float_failed():
      """
      Feature: NeighborExchangeV2
      Description: recv_lens entries must be int, but a float (1.0) is given
      Expectation: throw TypeError
      """

      class Net(nn.Cell):
          def __init__(self):
              super().__init__()
              self.neighborexchangev2 = NeighborExchangeV2(
                  send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  send_lens=[0, 1, 0, 0],
                  recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  recv_lens=[0, 1.0, 0, 0],
                  data_format="NCHW",
              )

          def construct(self, x):
              return self.neighborexchangev2(x)[0]

      context.set_auto_parallel_context(device_num=8, global_rank=0)
      net = Net()
      with pytest.raises(TypeError):
          _cell_graph_executor.compile(net, _x1)
  def test_neighborexchangev2_group_is_tuple_failed():
      """
      Feature: NeighborExchangeV2
      Description: group must be a string, but a tuple is given
      Expectation: throw TypeError
      """

      class Net(nn.Cell):
          def __init__(self):
              super().__init__()
              self.neighborexchangev2 = NeighborExchangeV2(
                  send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  send_lens=[0, 1, 0, 0],
                  recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  recv_lens=[0, 1, 0, 0],
                  data_format="NCHW",
                  group=("str",),
              )

          def construct(self, x):
              return self.neighborexchangev2(x)[0]

      context.set_auto_parallel_context(device_num=8, global_rank=0)
      net = Net()
      with pytest.raises(TypeError):
          _cell_graph_executor.compile(net, _x1)
  def test_neighborexchangev2_send_lens_larger_than_input_shape_failed():
      """
      Feature: NeighborExchangeV2
      Description: send_lens must not exceed the input shape, but send_lens[1] is 35 while
                   the input H size is 32
      Expectation: throw ValueError
      """

      class Net(nn.Cell):
          def __init__(self):
              super().__init__()
              self.neighborexchangev2 = NeighborExchangeV2(
                  send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  send_lens=[0, 35, 0, 0],
                  recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  recv_lens=[0, 1, 0, 0],
                  data_format="NCHW",
              )

          def construct(self, x):
              return self.neighborexchangev2(x)[0]

      context.set_auto_parallel_context(device_num=8, global_rank=0)
      net = Net()
      with pytest.raises(ValueError):
          _cell_graph_executor.compile(net, _x1)
  def test_neighborexchangev2_send_rank_ids_value_invalid_failed():
      """
      Feature: NeighborExchangeV2
      Description: send_rank_ids entries must be >= 0 or -1, but -3 is given
      Expectation: throw ValueError
      """

      class Net(nn.Cell):
          def __init__(self):
              super().__init__()
              self.neighborexchangev2 = NeighborExchangeV2(
                  send_rank_ids=[-1, -1, -1, -3, 1, -1, -1, -1],
                  send_lens=[0, 1, 0, 0],
                  recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  recv_lens=[0, 1, 0, 0],
                  data_format="NCHW",
              )

          def construct(self, x):
              return self.neighborexchangev2(x)[0]

      context.set_auto_parallel_context(device_num=8, global_rank=0)
      net = Net()
      with pytest.raises(ValueError):
          _cell_graph_executor.compile(net, _x1)
  def test_neighborexchangev2_recv_rank_ids_value_invalid_failed():
      """
      Feature: NeighborExchangeV2
      Description: recv_rank_ids entries must be >= 0 or -1, but -3 is given
      Expectation: throw ValueError
      """

      class Net(nn.Cell):
          def __init__(self):
              super().__init__()
              self.neighborexchangev2 = NeighborExchangeV2(
                  send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  send_lens=[0, 1, 0, 0],
                  recv_rank_ids=[-1, -1, -1, -3, 1, -1, -1, -1],
                  recv_lens=[0, 1, 0, 0],
                  data_format="NCHW",
              )

          def construct(self, x):
              return self.neighborexchangev2(x)[0]

      context.set_auto_parallel_context(device_num=8, global_rank=0)
      net = Net()
      with pytest.raises(ValueError):
          _cell_graph_executor.compile(net, _x1)
  def test_neighborexchangev2_send_lens_value_invalid_failed():
      """
      Feature: NeighborExchangeV2
      Description: send_lens entries must be >= 0, but -3 is given
      Expectation: throw ValueError
      """

      class Net(nn.Cell):
          def __init__(self):
              super().__init__()
              self.neighborexchangev2 = NeighborExchangeV2(
                  send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  send_lens=[0, -3, 0, 0],
                  recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  recv_lens=[0, 1, 0, 0],
                  data_format="NCHW",
              )

          def construct(self, x):
              return self.neighborexchangev2(x)[0]

      context.set_auto_parallel_context(device_num=8, global_rank=0)
      net = Net()
      with pytest.raises(ValueError):
          _cell_graph_executor.compile(net, _x1)
  def test_neighborexchangev2_recv_lens_value_invalid_failed():
      """
      Feature: NeighborExchangeV2
      Description: recv_lens entries must be >= 0, but -3 is given
      Expectation: throw ValueError
      """

      class Net(nn.Cell):
          def __init__(self):
              super().__init__()
              self.neighborexchangev2 = NeighborExchangeV2(
                  send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  send_lens=[0, 1, 0, 0],
                  recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  recv_lens=[0, -3, 0, 0],
                  data_format="NCHW",
              )

          def construct(self, x):
              return self.neighborexchangev2(x)[0]

      context.set_auto_parallel_context(device_num=8, global_rank=0)
      net = Net()
      with pytest.raises(ValueError):
          _cell_graph_executor.compile(net, _x1)
  def test_neighborexchangev2_send_rank_ids_repeat_failed():
      """
      Feature: NeighborExchangeV2
      Description: send_rank_ids must not repeat a rank, but rank 1 is given twice
      Expectation: throw ValueError
      """

      class Net(nn.Cell):
          def __init__(self):
              super().__init__()
              self.neighborexchangev2 = NeighborExchangeV2(
                  send_rank_ids=[1, -1, -1, -1, 1, -1, -1, -1],
                  send_lens=[0, 1, 0, 0],
                  recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  recv_lens=[0, 1, 0, 0],
                  data_format="NCHW",
              )

          def construct(self, x):
              return self.neighborexchangev2(x)[0]

      context.set_auto_parallel_context(device_num=8, global_rank=0)
      net = Net()
      with pytest.raises(ValueError):
          _cell_graph_executor.compile(net, _x1)
  def test_neighborexchangev2_recv_rank_ids_repeat_failed():
      """
      Feature: NeighborExchangeV2
      Description: recv_rank_ids must not repeat a rank, but rank 1 is given twice
      Expectation: throw ValueError
      """

      class Net(nn.Cell):
          def __init__(self):
              super().__init__()
              self.neighborexchangev2 = NeighborExchangeV2(
                  send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
                  send_lens=[0, 1, 0, 0],
                  recv_rank_ids=[1, -1, -1, -1, 1, -1, -1, -1],
                  recv_lens=[0, 1, 0, 0],
                  data_format="NCHW",
              )

          def construct(self, x):
              return self.neighborexchangev2(x)[0]

      context.set_auto_parallel_context(device_num=8, global_rank=0)
      net = Net()
      with pytest.raises(ValueError):
          _cell_graph_executor.compile(net, _x1)