You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_neighborexchange.py 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import pytest
  16. import numpy as np
  17. import mindspore as ms
  18. import mindspore.context as context
  19. from mindspore import Tensor, Parameter
  20. import mindspore.nn as nn
  21. from mindspore.common.api import _cell_graph_executor
  22. from mindspore.nn import TrainOneStepCell, Momentum
  23. from mindspore.ops import operations as P
  24. from mindspore.ops.operations.comm_ops import NeighborExchange
  25. _w1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
  26. _x1 = Tensor(np.ones([32, 16]), dtype=ms.float32)
  27. _x2 = Tensor(np.ones([16, 32]), dtype=ms.float32)
  28. def compile_net(net):
  29. context.set_context(mode=context.GRAPH_MODE)
  30. optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
  31. train_net = TrainOneStepCell(net, optimizer)
  32. train_net.set_train()
  33. _cell_graph_executor.compile(train_net, _x1, _x2)
  34. def test_NeighborExchange_two_inputs_success():
  35. """
  36. Feature: NeighborExchange
  37. Description: two inputs and two outputs, with valid arguments
  38. Expectation: success
  39. """
  40. context.set_auto_parallel_context(device_num=8, global_rank=0)
  41. class MatMulNet(nn.Cell):
  42. def __init__(self, weight1):
  43. super(MatMulNet, self).__init__()
  44. self.matmul = P.MatMul()
  45. self.mul = P.Mul()
  46. self.alltoallv = NeighborExchange(send_rank_ids=[0, 1], recv_rank_ids=[1, 2],
  47. recv_shapes=([32, 32], [32, 64]),
  48. send_shapes=([32, 32], [32, 16]), recv_type=ms.float32)
  49. self.weight1 = Parameter(weight1, "w1")
  50. def construct(self, x1, x2):
  51. out = self.matmul(x1, x2)
  52. out = self.mul(out, self.weight1)
  53. out = self.alltoallv((out, x1))
  54. return out[0]
  55. net = MatMulNet(_w1)
  56. compile_net(net)
  57. def test_NeighborExchange_single_input_success():
  58. """
  59. Feature: NeighborExchange
  60. Description: one inputs and two outputs, with valid arguments
  61. Expectation: success
  62. """
  63. context.set_auto_parallel_context(device_num=8, global_rank=0)
  64. class MatMulNet2(nn.Cell):
  65. def __init__(self, weight1):
  66. super(MatMulNet2, self).__init__()
  67. self.matmul = P.MatMul()
  68. self.mul = P.Mul()
  69. self.alltoallv = NeighborExchange(send_rank_ids=[0], recv_rank_ids=[1, 2], recv_shapes=([32, 32], [32, 64]),
  70. send_shapes=([32, 32],), recv_type=ms.float32)
  71. self.weight1 = Parameter(weight1, "w1")
  72. def construct(self, x1, x2):
  73. out = self.matmul(x1, x2)
  74. out = self.mul(out, self.weight1)
  75. out = self.alltoallv((out,))
  76. return out[0]
  77. net = MatMulNet2(_w1)
  78. compile_net(net)
  79. def test_NeighborExchange_empty_send_success():
  80. """
  81. Feature: NeighborExchange
  82. Description: empty inputs, with valid arguments
  83. Expectation: success
  84. """
  85. context.set_auto_parallel_context(device_num=8, global_rank=0)
  86. class Net(nn.Cell):
  87. def __init__(self):
  88. super(Net, self).__init__()
  89. self.alltoallv = NeighborExchange(send_rank_ids=[], recv_rank_ids=[1], recv_shapes=([1],),
  90. send_shapes=(), recv_type=ms.float32)
  91. def construct(self, x1):
  92. self.alltoallv()
  93. return x1
  94. net = Net()
  95. _cell_graph_executor.compile(net, _x1)
  96. def test_NeighborExchange_empty_recv_success():
  97. """
  98. Feature: NeighborExchange
  99. Description: empty outputs, with valid arguments
  100. Expectation: success
  101. """
  102. context.set_auto_parallel_context(device_num=8, global_rank=0)
  103. class Net(nn.Cell):
  104. def __init__(self):
  105. super(Net, self).__init__()
  106. self.alltoallv = NeighborExchange(send_rank_ids=[0], recv_rank_ids=[], recv_shapes=(),
  107. send_shapes=([32, 16],), recv_type=ms.float32)
  108. def construct(self, x1):
  109. self.alltoallv((x1,))
  110. return x1
  111. net = Net()
  112. _cell_graph_executor.compile(net, _x1)
  113. def test_NeighborExchange_empty_send_empty_recv_success():
  114. """
  115. Feature: NeighborExchange
  116. Description: empty inputs and empty outputs, with valid arguments
  117. Expectation: success
  118. """
  119. context.set_auto_parallel_context(device_num=8, global_rank=0)
  120. class Net(nn.Cell):
  121. def __init__(self):
  122. super(Net, self).__init__()
  123. self.alltoallv = NeighborExchange(send_rank_ids=[], recv_rank_ids=[], recv_shapes=(),
  124. send_shapes=(), recv_type=ms.float32)
  125. def construct(self, x1):
  126. self.alltoallv()
  127. return x1
  128. net = Net()
  129. _cell_graph_executor.compile(net, _x1)
  130. def test_NeighborExchange_recv_shape_num_diff_with_recv_rank_size_failed():
  131. """
  132. Feature: NeighborExchange
  133. Description: send_rank_ids and send_shapes are set as 1 input, but gives 2
  134. Expectation: throw ValueError
  135. """
  136. context.set_auto_parallel_context(device_num=8, global_rank=0)
  137. class Net(nn.Cell):
  138. def __init__(self, weight1):
  139. super(Net, self).__init__()
  140. self.matmul = P.MatMul()
  141. self.mul = P.Mul()
  142. self.alltoallv = NeighborExchange(send_rank_ids=[0], recv_rank_ids=[1, 2], recv_shapes=([32, 32],),
  143. send_shapes=([32, 32],), recv_type=ms.float32)
  144. self.weight1 = Parameter(weight1, "w1")
  145. def construct(self, x1, x2):
  146. out = self.matmul(x1, x2)
  147. out = self.mul(out, self.weight1)
  148. out = self.alltoallv((out,))
  149. return out[0]
  150. net = Net(_w1)
  151. with pytest.raises(ValueError):
  152. compile_net(net)
  153. def test_NeighborExchange_send_shape_num_diff_with_send_rank_size_failed():
  154. """
  155. Feature: NeighborExchange
  156. Description: send_rank_ids is set as 2 inputs, but send_shapes are set as 1 input
  157. Expectation: throw ValueError
  158. """
  159. context.set_auto_parallel_context(device_num=8, global_rank=0)
  160. class Net(nn.Cell):
  161. def __init__(self, weight1):
  162. super(Net, self).__init__()
  163. self.matmul = P.MatMul()
  164. self.mul = P.Mul()
  165. self.alltoallv = NeighborExchange(send_rank_ids=[0, 1], recv_rank_ids=[1, 2],
  166. recv_shapes=([32, 32], [32, 32]),
  167. send_shapes=([32, 32],), recv_type=ms.float32)
  168. self.weight1 = Parameter(weight1, "w1")
  169. def construct(self, x1, x2):
  170. out = self.matmul(x1, x2)
  171. out = self.mul(out, self.weight1)
  172. out = self.alltoallv((out,))
  173. return out[0]
  174. net = Net(_w1)
  175. with pytest.raises(ValueError):
  176. compile_net(net)
  177. def test_NeighborExchange_send_shape_num_diff_with_input_num_failed():
  178. """
  179. Feature: NeighborExchange
  180. Description: send_rank_ids and send_shapes are set as 2 inputs, but has only 1 input
  181. Expectation: throw Exception
  182. """
  183. context.set_auto_parallel_context(device_num=8, global_rank=0)
  184. class Net(nn.Cell):
  185. def __init__(self, weight1):
  186. super(Net, self).__init__()
  187. self.matmul = P.MatMul()
  188. self.mul = P.Mul()
  189. self.alltoallv = NeighborExchange(send_rank_ids=[0, 1], recv_rank_ids=[1, 2],
  190. recv_shapes=([32, 32], [32, 32]),
  191. send_shapes=([32, 32], [32, 32]), recv_type=ms.float32)
  192. self.weight1 = Parameter(weight1, "w1")
  193. def construct(self, x1, x2):
  194. out = self.matmul(x1, x2)
  195. out = self.mul(out, self.weight1)
  196. out = self.alltoallv((out,))
  197. return out[0]
  198. net = Net(_w1)
  199. with pytest.raises(Exception):
  200. compile_net(net)
  201. def test_NeighborExchange_send_shape_diff_with_input_shape_failed():
  202. """
  203. Feature: NeighborExchange
  204. Description: send_shapes is set as [16, 16], but input is [32, 32]
  205. Expectation: throw Exception
  206. """
  207. context.set_auto_parallel_context(device_num=8, global_rank=0)
  208. class Net(nn.Cell):
  209. def __init__(self, weight1):
  210. super(Net, self).__init__()
  211. self.matmul = P.MatMul()
  212. self.mul = P.Mul()
  213. self.alltoallv = NeighborExchange(send_rank_ids=[0], recv_rank_ids=[1, 2], recv_shapes=([32, 32], [32, 64]),
  214. send_shapes=([16, 16],), recv_type=ms.float32)
  215. self.weight1 = Parameter(weight1, "w1")
  216. def construct(self, x1, x2):
  217. out = self.matmul(x1, x2)
  218. out = self.mul(out, self.weight1)
  219. out = self.alltoallv((out,))
  220. return out[0]
  221. net = Net(_w1)
  222. with pytest.raises(Exception):
  223. compile_net(net)
  224. def test_NeighborExchange_attr_check_send_rank_ids_is_tuple_failed():
  225. """
  226. Feature: NeighborExchange
  227. Description: send_rank_ids should be list, but a tuple is given
  228. Expectation: throw TypeError
  229. """
  230. context.set_auto_parallel_context(device_num=8, global_rank=0)
  231. class Net(nn.Cell):
  232. def __init__(self):
  233. super(Net, self).__init__()
  234. self.alltoallv = NeighborExchange(send_rank_ids=(0), recv_rank_ids=[1, 2], recv_shapes=([32, 32], [32, 64]),
  235. send_shapes=([32, 16],), recv_type=ms.float32)
  236. def construct(self, x1):
  237. out = self.alltoallv((x1,))
  238. return out[0]
  239. net = Net()
  240. with pytest.raises(TypeError):
  241. _cell_graph_executor.compile(net, _x1)
  242. def test_NeighborExchange_attr_check_send_rank_ids_is_tuple_2_failed():
  243. """
  244. Feature: NeighborExchange
  245. Description: send_rank_ids should be list, but a tuple is given
  246. Expectation: throw TypeError
  247. """
  248. context.set_auto_parallel_context(device_num=8, global_rank=0)
  249. class Net(nn.Cell):
  250. def __init__(self):
  251. super(Net, self).__init__()
  252. self.alltoallv = NeighborExchange(send_rank_ids=(0,), recv_rank_ids=[1, 2],
  253. recv_shapes=([32, 32], [32, 64]),
  254. send_shapes=([32, 16],), recv_type=ms.float32)
  255. def construct(self, x1):
  256. out = self.alltoallv((x1,))
  257. return out[0]
  258. net = Net()
  259. with pytest.raises(TypeError):
  260. _cell_graph_executor.compile(net, _x1)
  261. def test_NeighborExchange_attr_check_send_rank_ids_is_float_failed():
  262. """
  263. Feature: NeighborExchange
  264. Description: send_rank_ids should be int, but a float is given
  265. Expectation: throw TypeError
  266. """
  267. context.set_auto_parallel_context(device_num=8, global_rank=0)
  268. class Net(nn.Cell):
  269. def __init__(self):
  270. super(Net, self).__init__()
  271. self.alltoallv = NeighborExchange(send_rank_ids=[1.0], recv_rank_ids=[1, 2],
  272. recv_shapes=([32, 32], [32, 64]),
  273. send_shapes=([32, 16],), recv_type=ms.float32)
  274. def construct(self, x1):
  275. out = self.alltoallv((x1,))
  276. return out[0]
  277. net = Net()
  278. with pytest.raises(TypeError):
  279. _cell_graph_executor.compile(net, _x1)
  280. def test_NeighborExchange_attr_check_recv_rank_ids_is_tuple_failed():
  281. """
  282. Feature: NeighborExchange
  283. Description: recv_rank_ids should be list, but a tuple is given
  284. Expectation: throw TypeError
  285. """
  286. context.set_auto_parallel_context(device_num=8, global_rank=0)
  287. class Net(nn.Cell):
  288. def __init__(self):
  289. super(Net, self).__init__()
  290. self.alltoallv = NeighborExchange(send_rank_ids=[0], recv_rank_ids=([1, 2],),
  291. recv_shapes=([32, 32], [32, 64]),
  292. send_shapes=([32, 16],), recv_type=ms.float32)
  293. def construct(self, x1):
  294. out = self.alltoallv((x1,))
  295. return out[0]
  296. net = Net()
  297. with pytest.raises(TypeError):
  298. _cell_graph_executor.compile(net, _x1)
  299. def test_NeighborExchange_attr_check_recv_rank_ids_is_tuple_2_failed():
  300. """
  301. Feature: NeighborExchange
  302. Description: recv_rank_ids should be list, but a tuple is given
  303. Expectation: throw TypeError
  304. """
  305. context.set_auto_parallel_context(device_num=8, global_rank=0)
  306. class Net(nn.Cell):
  307. def __init__(self):
  308. super(Net, self).__init__()
  309. self.alltoallv = NeighborExchange(send_rank_ids=[0], recv_rank_ids=(1, 2,),
  310. recv_shapes=([32, 32], [32, 64]),
  311. send_shapes=([32, 16],), recv_type=ms.float32)
  312. def construct(self, x1):
  313. out = self.alltoallv((x1,))
  314. return out[0]
  315. net = Net()
  316. with pytest.raises(TypeError):
  317. _cell_graph_executor.compile(net, _x1)
  318. def test_NeighborExchange_attr_check_recv_rank_ids_is_float_failed():
  319. """
  320. Feature: NeighborExchange
  321. Description: recv_rank_ids should be int, but a float is given
  322. Expectation: throw TypeError
  323. """
  324. context.set_auto_parallel_context(device_num=8, global_rank=0)
  325. class Net(nn.Cell):
  326. def __init__(self):
  327. super(Net, self).__init__()
  328. self.alltoallv = NeighborExchange(send_rank_ids=[1], recv_rank_ids=[1, 2.0],
  329. recv_shapes=([32, 32], [32, 64]),
  330. send_shapes=([32, 16],), recv_type=ms.float32)
  331. def construct(self, x1):
  332. out = self.alltoallv((x1,))
  333. return out[0]
  334. net = Net()
  335. with pytest.raises(TypeError):
  336. _cell_graph_executor.compile(net, _x1)
  337. def test_NeighborExchange_attr_check_send_shape_not_tuple_failed():
  338. """
  339. Feature: NeighborExchange
  340. Description: send_shapes should be tuple(list), but a list is given
  341. Expectation: throw TypeError
  342. """
  343. context.set_auto_parallel_context(device_num=8, global_rank=0)
  344. class Net(nn.Cell):
  345. def __init__(self):
  346. super(Net, self).__init__()
  347. self.alltoallv = NeighborExchange(send_rank_ids=[1], recv_rank_ids=[1, 2],
  348. recv_shapes=([32, 32], [32, 64]),
  349. send_shapes=([32, 16]), recv_type=ms.float32)
  350. def construct(self, x1):
  351. out = self.alltoallv((x1,))
  352. return out[0]
  353. net = Net()
  354. with pytest.raises(TypeError):
  355. _cell_graph_executor.compile(net, _x1)
  356. def test_NeighborExchange_attr_check_send_shape_list_failed():
  357. """
  358. Feature: NeighborExchange
  359. Description: send_shapes should be tuple(list), but a list(list) is given
  360. Expectation: throw TypeError
  361. """
  362. context.set_auto_parallel_context(device_num=8, global_rank=0)
  363. class Net(nn.Cell):
  364. def __init__(self):
  365. super(Net, self).__init__()
  366. self.alltoallv = NeighborExchange(send_rank_ids=[1], recv_rank_ids=[1, 2],
  367. recv_shapes=([32, 32], [32, 64]),
  368. send_shapes=[[32, 16]], recv_type=ms.float32)
  369. def construct(self, x1):
  370. out = self.alltoallv((x1,))
  371. return out[0]
  372. net = Net()
  373. with pytest.raises(TypeError):
  374. _cell_graph_executor.compile(net, _x1)
  375. def test_NeighborExchange_attr_check_recv_type_numpy_failed():
  376. """
  377. Feature: NeighborExchange
  378. Description: recv_type should be mindspore type, but a numpy type is given
  379. Expectation: throw TypeError
  380. """
  381. context.set_auto_parallel_context(device_num=8, global_rank=0)
  382. class Net(nn.Cell):
  383. def __init__(self):
  384. super(Net, self).__init__()
  385. self.alltoallv = NeighborExchange(send_rank_ids=[1], recv_rank_ids=[1, 2],
  386. recv_shapes=([32, 32], [32, 64]),
  387. send_shapes=([32, 16],), recv_type=np.float32)
  388. def construct(self, x1):
  389. out = self.alltoallv((x1,))
  390. return out[0]
  391. net = Net()
  392. with pytest.raises(TypeError):
  393. _cell_graph_executor.compile(net, _x1)
  394. def test_NeighborExchange_attr_invalid_grpup_failed():
  395. """
  396. Feature: NeighborExchange
  397. Description: group should be str, but a tuple is given
  398. Expectation: throw TypeError
  399. """
  400. context.set_auto_parallel_context(device_num=8, global_rank=0)
  401. class Net(nn.Cell):
  402. def __init__(self):
  403. super(Net, self).__init__()
  404. self.alltoallv = NeighborExchange(send_rank_ids=[1], recv_rank_ids=[1, 2],
  405. recv_shapes=([32, 32], [32, 64]),
  406. send_shapes=([32, 16],), recv_type=ms.float32, group=("str",))
  407. def construct(self, x1):
  408. out = self.alltoallv((x1,))
  409. return out[0]
  410. net = Net()
  411. with pytest.raises(TypeError):
  412. _cell_graph_executor.compile(net, _x1)