You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_neighborexchangev2.py 22 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import pytest
  16. import numpy as np
  17. import mindspore as ms
  18. import mindspore.context as context
  19. from mindspore import Tensor
  20. import mindspore.nn as nn
  21. from mindspore.common.api import _cell_graph_executor
  22. from mindspore.nn import TrainOneStepCell, Momentum
  23. from mindspore.ops.operations.comm_ops import NeighborExchangeV2
  24. _x1 = Tensor(np.ones([1, 1, 32, 16]), dtype=ms.float32)
  25. _x2 = Tensor(np.ones([1, 1, 33, 16]), dtype=ms.float32)
  26. def compile_net(net, x1, x2):
  27. context.set_context(mode=context.GRAPH_MODE)
  28. optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
  29. train_net = TrainOneStepCell(net, optimizer)
  30. train_net.set_train()
  31. _cell_graph_executor.compile(train_net, x1, x2)
  32. def test_neighborexchangev2_single_input_success():
  33. """
  34. Feature: NeighborExchangeV2
  35. Description: one inputs and one outputs, with valid arguments
  36. Expectation: success
  37. """
  38. context.set_auto_parallel_context(device_num=8, global_rank=0)
  39. class Net(nn.Cell):
  40. def __init__(self):
  41. super(Net, self).__init__()
  42. self.linear = nn.Dense(16, 16)
  43. self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
  44. send_lens=[0, 1, 0, 0],
  45. recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
  46. recv_lens=[0, 1, 0, 0], data_format="NCHW")
  47. def construct(self, x1, x2):
  48. y = self.linear(x1)
  49. y = self.neighborexchangev2(y)
  50. y = y + x2
  51. return y
  52. net = Net()
  53. compile_net(net, _x1, _x2)
  54. def test_neighborexchangev2_empty_send_success():
  55. """
  56. Feature: NeighborExchangeV2
  57. Description: empty inputs, with valid arguments
  58. Expectation: success
  59. """
  60. context.set_auto_parallel_context(device_num=8, global_rank=0)
  61. class Net(nn.Cell):
  62. def __init__(self):
  63. super(Net, self).__init__()
  64. self.linear = nn.Dense(16, 16)
  65. self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, -1, -1, -1, -1],
  66. send_lens=[1, 2, 3, 4],
  67. recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
  68. recv_lens=[0, 1, 0, 0],
  69. data_format="NCHW")
  70. def construct(self, x1, x2):
  71. y = self.linear(x1)
  72. y = self.neighborexchangev2(y)
  73. y = y + x2
  74. return y
  75. net = Net()
  76. compile_net(net, _x1, _x2)
  77. def test_neighborexchangev2_empty_recv_success():
  78. """
  79. Feature: NeighborExchangeV2
  80. Description: empty outputs, with valid arguments
  81. Expectation: success
  82. """
  83. context.set_auto_parallel_context(device_num=8, global_rank=0)
  84. class Net(nn.Cell):
  85. def __init__(self):
  86. super(Net, self).__init__()
  87. self.linear = nn.Dense(16, 16)
  88. self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
  89. send_lens=[0, 1, 0, 0],
  90. recv_rank_ids=[-1, -1, -1, -1, -1, -1, -1, -1],
  91. recv_lens=[1, 2, 3, 4],
  92. data_format="NCHW")
  93. def construct(self, x1, x2):
  94. y = self.linear(x1)
  95. y = self.neighborexchangev2(y)
  96. y = y + x2
  97. return y
  98. net = Net()
  99. compile_net(net, _x1, _x1)
  100. def test_neighborexchangev2_empty_send_empty_recv_success():
  101. """
  102. Feature: NeighborExchangeV2
  103. Description: empty inputs and empty outputs, with valid arguments
  104. Expectation: success
  105. """
  106. context.set_auto_parallel_context(device_num=8, global_rank=0)
  107. class Net(nn.Cell):
  108. def __init__(self):
  109. super(Net, self).__init__()
  110. self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, -1, -1, -1, -1],
  111. send_lens=[0, 1, 0, 0],
  112. recv_rank_ids=[-1, -1, -1, -1, -1, -1, -1, -1],
  113. recv_lens=[1, 2, 3, 4],
  114. data_format="NCHW")
  115. def construct(self, x1):
  116. y = self.neighborexchangev2(x1)
  117. return y
  118. net = Net()
  119. _cell_graph_executor.compile(net, _x1)
  120. def test_neighborexchangev2_invalid_dataformat_failed():
  121. """
  122. Feature: NeighborExchangeV2
  123. Description: data_format should be NCHW, but gives NHWC
  124. Expectation: throw ValueError
  125. """
  126. context.set_auto_parallel_context(device_num=8, global_rank=0)
  127. class Net(nn.Cell):
  128. def __init__(self):
  129. super(Net, self).__init__()
  130. self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
  131. send_lens=[0, 1, 0, 0],
  132. recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
  133. recv_lens=[0, 1, 0, 0],
  134. data_format="NHWC")
  135. def construct(self, x):
  136. out = self.neighborexchangev2(x)
  137. return out[0]
  138. net = Net()
  139. with pytest.raises(ValueError):
  140. _cell_graph_executor.compile(net, _x1)
  141. def test_neighborexchangev2_invalid_send_rank_ids_size_failed():
  142. """
  143. Feature: NeighborExchangeV2
  144. Description: send_rank_ids size should be 8, but gives 5
  145. Expectation: throw ValueError
  146. """
  147. context.set_auto_parallel_context(device_num=8, global_rank=0)
  148. class Net(nn.Cell):
  149. def __init__(self):
  150. super(Net, self).__init__()
  151. self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1],
  152. send_lens=[0, 1, 0, 0],
  153. recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
  154. recv_lens=[0, 1, 0, 0],
  155. data_format="NCHW")
  156. def construct(self, x):
  157. out = self.neighborexchangev2(x)
  158. return out[0]
  159. net = Net()
  160. with pytest.raises(ValueError):
  161. _cell_graph_executor.compile(net, _x1)
  162. def test_neighborexchangev2_invalid_recv_rank_ids_size_failed():
  163. """
  164. Feature: NeighborExchangeV2
  165. Description: recv_rank_ids size should be 8, but gives 5
  166. Expectation: throw ValueError
  167. """
  168. context.set_auto_parallel_context(device_num=8, global_rank=0)
  169. class Net(nn.Cell):
  170. def __init__(self):
  171. super(Net, self).__init__()
  172. self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
  173. send_lens=[0, 1, 0, 0],
  174. recv_rank_ids=[-1, -1, -1, -1, 1],
  175. recv_lens=[0, 1, 0, 0],
  176. data_format="NCHW")
  177. def construct(self, x):
  178. out = self.neighborexchangev2(x)
  179. return out[0]
  180. net = Net()
  181. with pytest.raises(ValueError):
  182. _cell_graph_executor.compile(net, _x1)
  183. def test_neighborexchangev2_invalid_send_lens_size_failed():
  184. """
  185. Feature: NeighborExchangeV2
  186. Description: send_lens size should be 4, but gives 5
  187. Expectation: throw ValueError
  188. """
  189. context.set_auto_parallel_context(device_num=8, global_rank=0)
  190. class Net(nn.Cell):
  191. def __init__(self):
  192. super(Net, self).__init__()
  193. self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
  194. send_lens=[0, 1, 0, 0, 2],
  195. recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
  196. recv_lens=[0, 1, 0, 0],
  197. data_format="NCHW")
  198. def construct(self, x):
  199. out = self.neighborexchangev2(x)
  200. return out[0]
  201. net = Net()
  202. with pytest.raises(ValueError):
  203. _cell_graph_executor.compile(net, _x1)
  204. def test_neighborexchangev2_invalid_recv_lens_size_failed():
  205. """
  206. Feature: NeighborExchangeV2
  207. Description: recv_lens size should be 4, but gives 5
  208. Expectation: throw ValueError
  209. """
  210. context.set_auto_parallel_context(device_num=8, global_rank=0)
  211. class Net(nn.Cell):
  212. def __init__(self):
  213. super(Net, self).__init__()
  214. self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
  215. send_lens=[0, 1, 0, 0],
  216. recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
  217. recv_lens=[0, 1, 0, 0, 2],
  218. data_format="NCHW")
  219. def construct(self, x):
  220. out = self.neighborexchangev2(x)
  221. return out[0]
  222. net = Net()
  223. with pytest.raises(ValueError):
  224. _cell_graph_executor.compile(net, _x1)
  225. def test_neighborexchangev2_invalid_input_size_failed():
  226. """
  227. Feature: NeighborExchangeV2
  228. Description: input should be one tensor, but gives 2
  229. Expectation: throw ValueError
  230. """
  231. context.set_auto_parallel_context(device_num=8, global_rank=0)
  232. class Net(nn.Cell):
  233. def __init__(self):
  234. super(Net, self).__init__()
  235. self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
  236. send_lens=[0, 1, 0, 0],
  237. recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
  238. recv_lens=[0, 1, 0, 0],
  239. data_format="NCHW")
  240. def construct(self, x1, x2):
  241. out = self.neighborexchangev2(x1, x2)
  242. return out[0]
  243. net = Net()
  244. with pytest.raises(ValueError):
  245. _cell_graph_executor.compile(net, _x1, _x2)
  246. def test_neighborexchangev2_recv_rank_ids_invalid_value_failed():
  247. """
  248. Feature: NeighborExchangeV2
  249. Description: recv_rank_ids should can be concat, recv_rank_ids[3] and [4] is 1, [5] is -1 given
  250. Expectation: throw Exception
  251. """
  252. context.set_auto_parallel_context(device_num=8, global_rank=0)
  253. class Net(nn.Cell):
  254. def __init__(self):
  255. super(Net, self).__init__()
  256. self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
  257. send_lens=[0, 1, 0, 0],
  258. recv_rank_ids=[-1, -1, -1, 1, 1, -1, -1, -1],
  259. recv_lens=[0, 1, 0, 0],
  260. data_format="NCHW")
  261. def construct(self, x):
  262. out = self.neighborexchangev2(x)
  263. return out[0]
  264. net = Net()
  265. with pytest.raises(ValueError):
  266. _cell_graph_executor.compile(net, _x1)
  267. def test_neighborexchangev2_attr_check_send_rank_ids_is_tuple_failed():
  268. """
  269. Feature: NeighborExchangeV2
  270. Description: send_rank_ids should be list, but a tuple is given
  271. Expectation: throw TypeError
  272. """
  273. context.set_auto_parallel_context(device_num=8, global_rank=0)
  274. class Net(nn.Cell):
  275. def __init__(self):
  276. super(Net, self).__init__()
  277. self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=(-1, -1, -1, -1, 1, -1, -1, -1),
  278. send_lens=[0, 1, 0, 0],
  279. recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
  280. recv_lens=[0, 1, 0, 0],
  281. data_format="NCHW")
  282. def construct(self, x):
  283. out = self.neighborexchangev2(x)
  284. return out[0]
  285. net = Net()
  286. with pytest.raises(TypeError):
  287. _cell_graph_executor.compile(net, _x1)
  288. def test_neighborexchangev2_attr_check_send_lens_is_tuple_failed():
  289. """
  290. Feature: NeighborExchangeV2
  291. Description: send_lens should be list, but a tuple is given
  292. Expectation: throw TypeError
  293. """
  294. context.set_auto_parallel_context(device_num=8, global_rank=0)
  295. class Net(nn.Cell):
  296. def __init__(self):
  297. super(Net, self).__init__()
  298. self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
  299. send_lens=(0, 1, 0, 0),
  300. recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
  301. recv_lens=[0, 1, 0, 0],
  302. data_format="NCHW")
  303. def construct(self, x):
  304. out = self.neighborexchangev2(x)
  305. return out[0]
  306. net = Net()
  307. with pytest.raises(TypeError):
  308. _cell_graph_executor.compile(net, _x1)
  309. def test_neighborexchangev2_attr_check_recv_rank_ids_is_tuple_failed():
  310. """
  311. Feature: NeighborExchangeV2
  312. Description: recv_rank_ids should be list, but a tuple is given
  313. Expectation: throw TypeError
  314. """
  315. context.set_auto_parallel_context(device_num=8, global_rank=0)
  316. class Net(nn.Cell):
  317. def __init__(self):
  318. super(Net, self).__init__()
  319. self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
  320. send_lens=[0, 1, 0, 0],
  321. recv_rank_ids=(-1, -1, -1, -1, 1, -1, -1, -1),
  322. recv_lens=[0, 1, 0, 0],
  323. data_format="NCHW")
  324. def construct(self, x):
  325. out = self.neighborexchangev2(x)
  326. return out[0]
  327. net = Net()
  328. with pytest.raises(TypeError):
  329. _cell_graph_executor.compile(net, _x1)
  330. def test_neighborexchangev2_attr_check_recv_lens_is_tuple_failed():
  331. """
  332. Feature: NeighborExchangeV2
  333. Description: recv_lens should be list, but a tuple is given
  334. Expectation: throw TypeError
  335. """
  336. context.set_auto_parallel_context(device_num=8, global_rank=0)
  337. class Net(nn.Cell):
  338. def __init__(self):
  339. super(Net, self).__init__()
  340. self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
  341. send_lens=[0, 1, 0, 0],
  342. recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
  343. recv_lens=(0, 1, 0, 0),
  344. data_format="NCHW")
  345. def construct(self, x):
  346. out = self.neighborexchangev2(x)
  347. return out[0]
  348. net = Net()
  349. with pytest.raises(TypeError):
  350. _cell_graph_executor.compile(net, _x1)
  351. def test_neighborexchangev2_attr_check_send_rank_ids_is_float_failed():
  352. """
  353. Feature: NeighborExchangeV2
  354. Description: send_rank_ids should be int, but float is given
  355. Expectation: throw TypeError
  356. """
  357. context.set_auto_parallel_context(device_num=8, global_rank=0)
  358. class Net(nn.Cell):
  359. def __init__(self):
  360. super(Net, self).__init__()
  361. self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1.0, -1, -1, -1],
  362. send_lens=[0, 1, 0, 0],
  363. recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
  364. recv_lens=[0, 1, 0, 0],
  365. data_format="NCHW")
  366. def construct(self, x):
  367. out = self.neighborexchangev2(x)
  368. return out[0]
  369. net = Net()
  370. with pytest.raises(TypeError):
  371. _cell_graph_executor.compile(net, _x1)
  372. def test_neighborexchangev2_attr_check_send_lens_is_float_failed():
  373. """
  374. Feature: NeighborExchangeV2
  375. Description: send_lens should be int, but float is given
  376. Expectation: throw TypeError
  377. """
  378. context.set_auto_parallel_context(device_num=8, global_rank=0)
  379. class Net(nn.Cell):
  380. def __init__(self):
  381. super(Net, self).__init__()
  382. self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
  383. send_lens=[0, 1.0, 0, 0],
  384. recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
  385. recv_lens=[0, 1, 0, 0],
  386. data_format="NCHW")
  387. def construct(self, x):
  388. out = self.neighborexchangev2(x)
  389. return out[0]
  390. net = Net()
  391. with pytest.raises(TypeError):
  392. _cell_graph_executor.compile(net, _x1)
  393. def test_neighborexchangev2_attr_check_recv_rank_ids_is_float_failed():
  394. """
  395. Feature: NeighborExchangeV2
  396. Description: send_rank_ids should be int, but float is given
  397. Expectation: throw TypeError
  398. """
  399. context.set_auto_parallel_context(device_num=8, global_rank=0)
  400. class Net(nn.Cell):
  401. def __init__(self):
  402. super(Net, self).__init__()
  403. self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
  404. send_lens=[0, 1, 0, 0],
  405. recv_rank_ids=[-1, -1, -1, -1, 1.0, -1, -1, -1],
  406. recv_lens=[0, 1, 0, 0],
  407. data_format="NCHW")
  408. def construct(self, x):
  409. out = self.neighborexchangev2(x)
  410. return out[0]
  411. net = Net()
  412. with pytest.raises(TypeError):
  413. _cell_graph_executor.compile(net, _x1)
  414. def test_neighborexchangev2_attr_check_recv_lens_is_float_failed():
  415. """
  416. Feature: NeighborExchangeV2
  417. Description: ids in send_rank_ids should be int, but float is given
  418. Expectation: throw TypeError
  419. """
  420. context.set_auto_parallel_context(device_num=8, global_rank=0)
  421. class Net(nn.Cell):
  422. def __init__(self):
  423. super(Net, self).__init__()
  424. self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
  425. send_lens=[0, 1, 0, 0],
  426. recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
  427. recv_lens=[0, 1.0, 0, 0],
  428. data_format="NCHW")
  429. def construct(self, x):
  430. out = self.neighborexchangev2(x)
  431. return out[0]
  432. net = Net()
  433. with pytest.raises(TypeError):
  434. _cell_graph_executor.compile(net, _x1)
  435. def test_neighborexchangev2_group_is_tuple_failed():
  436. """
  437. Feature: NeighborExchangeV2
  438. Description: group should be a string, but tuple given
  439. Expectation: throw TypeError
  440. """
  441. context.set_auto_parallel_context(device_num=8, global_rank=0)
  442. class Net(nn.Cell):
  443. def __init__(self):
  444. super(Net, self).__init__()
  445. self.neighborexchangev2 = NeighborExchangeV2(send_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
  446. send_lens=[0, 1, 0, 0],
  447. recv_rank_ids=[-1, -1, -1, -1, 1, -1, -1, -1],
  448. recv_lens=[0, 1, 0, 0],
  449. data_format="NCHW", group=("str",))
  450. def construct(self, x):
  451. out = self.neighborexchangev2(x)
  452. return out[0]
  453. net = Net()
  454. with pytest.raises(TypeError):
  455. _cell_graph_executor.compile(net, _x1)