You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

hccl_adapter_test.cc 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <memory>
  17. #include "common/common_test.h"
  18. #include "runtime/hccl_adapter/all_to_all_v_calc_param.h"
  19. #include "backend/session/anf_runtime_algorithm.h"
  20. #include "mindspore/core/ir/dtype/type_id.h"
  21. namespace mindspore::hccl {
  22. class TestHcclAdapter : public UT::Common {
  23. public:
  24. TestHcclAdapter() {}
  25. protected:
  26. CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> inputs,
  27. const std::vector<int64_t> &send_rank_ids, const std::vector<int64_t> &recv_rank_ids) {
  28. MS_EXCEPTION_IF_NULL(graph);
  29. std::vector<AnfNodePtr> all_to_all_v_input = {NewValueNode(std::make_shared<Primitive>(kAllToAllVOpName))};
  30. all_to_all_v_input.insert(all_to_all_v_input.end(), inputs.begin(), inputs.end());
  31. auto all_to_all_v = graph->NewCNode(all_to_all_v_input);
  32. MS_EXCEPTION_IF_NULL(all_to_all_v);
  33. AnfAlgo::SetNodeAttr(kAttrSendRankIds, MakeValue<std::vector<int64_t>>(send_rank_ids), all_to_all_v);
  34. AnfAlgo::SetNodeAttr(kAttrRecvRankIds, MakeValue<std::vector<int64_t>>(recv_rank_ids), all_to_all_v);
  35. AnfAlgo::SetNodeAttr(kAttrGroup, MakeValue<std::string>("default_group"), all_to_all_v);
  36. return all_to_all_v;
  37. }
  38. void SetOutputs(const CNodePtr &cnode, const std::vector<std::vector<size_t>> &shape,
  39. const std::vector<TypeId> &data_type) {
  40. AnfAlgo::SetOutputInferTypeAndShape(data_type, shape, cnode.get());
  41. kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  42. builder.SetFusionType(kernel::FusionType::OPAQUE);
  43. builder.SetProcessor(kernel::Processor::AICORE);
  44. builder.SetKernelType(TBE_KERNEL);
  45. builder.SetInputsFormat(std::vector<std::string>(cnode->size() - 1, format_));
  46. builder.SetOutputsFormat(std::vector<std::string>(shape.size(), format_));
  47. builder.SetInputsDeviceType(std::vector<TypeId>(cnode->size() - 1, type_));
  48. builder.SetOutputsDeviceType(std::vector<TypeId>(shape.size(), type_));
  49. cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
  50. AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
  51. }
  52. std::vector<AnfNodePtr> CreateInputs(const FuncGraphPtr &graph, const std::vector<std::vector<size_t>> &shape,
  53. const std::vector<TypeId> &data_type) {
  54. MS_EXCEPTION_IF_NULL(graph);
  55. if (shape.size() != data_type.size()) {
  56. return {};
  57. }
  58. std::vector<AnfNodePtr> res;
  59. for (size_t i = 0; i < shape.size(); ++i) {
  60. auto node = graph->NewCNode(std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>("AnyNameOp"))});
  61. AnfAlgo::SetOutputInferTypeAndShape(std::vector<TypeId>{data_type[i]}, std::vector<std::vector<size_t>>{shape[i]},
  62. node.get());
  63. kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  64. builder.SetFusionType(kernel::FusionType::OPAQUE);
  65. builder.SetProcessor(kernel::Processor::AICORE);
  66. builder.SetKernelType(TBE_KERNEL);
  67. builder.SetInputsFormat({format_});
  68. builder.SetOutputsFormat({format_});
  69. builder.SetInputsDeviceType({type_});
  70. builder.SetOutputsDeviceType({type_});
  71. node->set_kernel_info(std::make_shared<device::KernelInfo>());
  72. AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), node.get());
  73. res.emplace_back(node);
  74. }
  75. return res;
  76. }
  77. TypeId type_ = TypeId::kNumberTypeInt32;
  78. std::string format_ = "NCHW";
  79. };
  80. /// Feature: AllToAllvCalcParam
  81. /// Description: on 2p, send to rank 1, and recv nothing
  82. /// Expectation: send count 0 1
  83. /// send offset 0 0
  84. /// recv count 0 0
  85. /// recv offset 0 0
  86. TEST_F(TestHcclAdapter, test_all_to_all_v_calc_param_2p_only_send) {
  87. auto graph = std::make_shared<FuncGraph>();
  88. ASSERT_TRUE(graph != nullptr);
  89. uint32_t rank_size = 2;
  90. std::vector<int64_t> send_rank_ids = {1};
  91. std::vector<int64_t> recv_rank_ids = {};
  92. auto alltoall = CreateAllToAllvNode(graph, CreateInputs(graph, {{1}}, {type_}), send_rank_ids, recv_rank_ids);
  93. ASSERT_TRUE(alltoall != nullptr);
  94. ASSERT_NO_THROW(SetOutputs(alltoall, {}, {}));
  95. AllToAllvCalcParam calc(alltoall, rank_size);
  96. ASSERT_NO_THROW(calc.CalcOpParam());
  97. EXPECT_EQ(calc.GetSendCounts(), std::vector<int64_t>({0, 1}));
  98. EXPECT_EQ(calc.GetSendDispls(), std::vector<int64_t>({0, 0}));
  99. EXPECT_EQ(calc.GetRecvCounts(), std::vector<int64_t>({0, 0}));
  100. EXPECT_EQ(calc.GetRecvDispls(), std::vector<int64_t>({0, 0}));
  101. }
  102. /// Feature: AllToAllvCalcParam
  103. /// Description: on 2p, send nothing, and recv from rank 0 and rank 1
  104. /// Expectation: send count 0 0
  105. /// send offset 0 0
  106. /// recv count 1 1
  107. /// recv offset 0 128
  108. TEST_F(TestHcclAdapter, test_all_to_all_v_calc_param_2p_only_recv) {
  109. auto graph = std::make_shared<FuncGraph>();
  110. ASSERT_TRUE(graph != nullptr);
  111. uint32_t rank_size = 2;
  112. std::vector<int64_t> send_rank_ids = {};
  113. std::vector<int64_t> recv_rank_ids = {0, 1};
  114. auto alltoall = CreateAllToAllvNode(graph, CreateInputs(graph, {}, {}), send_rank_ids, recv_rank_ids);
  115. ASSERT_TRUE(alltoall != nullptr);
  116. ASSERT_NO_THROW(SetOutputs(alltoall, {{1}, {1}}, {type_, type_}));
  117. AllToAllvCalcParam calc(alltoall, rank_size);
  118. ASSERT_NO_THROW(calc.CalcOpParam());
  119. EXPECT_EQ(calc.GetSendCounts(), std::vector<int64_t>({0, 0}));
  120. EXPECT_EQ(calc.GetSendDispls(), std::vector<int64_t>({0, 0}));
  121. EXPECT_EQ(calc.GetRecvCounts(), std::vector<int64_t>({1, 1}));
  122. EXPECT_EQ(calc.GetRecvDispls(), std::vector<int64_t>({0, 128}));
  123. }
  124. /// Feature: AllToAllvCalcParam
  125. /// Description: on 4p, send to rank1,2,3, and recv nothing
  126. /// Expectation: send count 0 1 1 1
  127. /// send offset 0 0 128 256
  128. /// recv count 0 0 0 0
  129. /// recv offset 0 0 0 0
  130. TEST_F(TestHcclAdapter, test_all_to_all_v_calc_param_4p_only_send) {
  131. auto graph = std::make_shared<FuncGraph>();
  132. ASSERT_TRUE(graph != nullptr);
  133. uint32_t rank_size = 4;
  134. std::vector<int64_t> send_rank_ids = {1, 2, 3};
  135. std::vector<int64_t> recv_rank_ids = {};
  136. auto alltoall = CreateAllToAllvNode(graph, CreateInputs(graph, {{1}, {1}, {1}}, {type_, type_, type_}), send_rank_ids,
  137. recv_rank_ids);
  138. ASSERT_TRUE(alltoall != nullptr);
  139. ASSERT_NO_THROW(SetOutputs(alltoall, {}, {}));
  140. AllToAllvCalcParam calc(alltoall, rank_size);
  141. ASSERT_NO_THROW(calc.CalcOpParam());
  142. EXPECT_EQ(calc.GetSendCounts(), std::vector<int64_t>({0, 1, 1, 1}));
  143. EXPECT_EQ(calc.GetSendDispls(), std::vector<int64_t>({0, 0, 128, 256}));
  144. EXPECT_EQ(calc.GetRecvCounts(), std::vector<int64_t>({0, 0, 0, 0}));
  145. EXPECT_EQ(calc.GetRecvDispls(), std::vector<int64_t>({0, 0, 0, 0}));
  146. }
  147. /// Feature: AllToAllvCalcParam
  148. /// Description: on 4p, send to rank1,3, and recv nothing
  149. /// Expectation: send count 0 1 0 1
  150. /// send offset 0 0 128 128
  151. /// recv count 0 0 0 0
  152. /// recv offset 0 0 0 0
  153. TEST_F(TestHcclAdapter, test_all_to_all_v_calc_param_4p_only_send_2) {
  154. auto graph = std::make_shared<FuncGraph>();
  155. ASSERT_TRUE(graph != nullptr);
  156. uint32_t rank_size = 4;
  157. std::vector<int64_t> send_rank_ids = {1, 3};
  158. std::vector<int64_t> recv_rank_ids = {};
  159. auto alltoall =
  160. CreateAllToAllvNode(graph, CreateInputs(graph, {{1}, {1}}, {type_, type_}), send_rank_ids, recv_rank_ids);
  161. ASSERT_TRUE(alltoall != nullptr);
  162. ASSERT_NO_THROW(SetOutputs(alltoall, {}, {}));
  163. AllToAllvCalcParam calc(alltoall, rank_size);
  164. ASSERT_NO_THROW(calc.CalcOpParam());
  165. EXPECT_EQ(calc.GetSendCounts(), std::vector<int64_t>({0, 1, 0, 1}));
  166. EXPECT_EQ(calc.GetSendDispls(), std::vector<int64_t>({0, 0, 128, 128}));
  167. EXPECT_EQ(calc.GetRecvCounts(), std::vector<int64_t>({0, 0, 0, 0}));
  168. EXPECT_EQ(calc.GetRecvDispls(), std::vector<int64_t>({0, 0, 0, 0}));
  169. }
  170. /// Feature: AllToAllvCalcParam
  171. /// Description: on 2p, send to rank1, and recv from rank1
  172. /// Expectation: send count 0 1
  173. /// send offset 0 0
  174. /// recv count 0 1
  175. /// recv offset 0 0
  176. TEST_F(TestHcclAdapter, test_all_to_all_v_calc_param_2p_exchange) {
  177. auto graph = std::make_shared<FuncGraph>();
  178. ASSERT_TRUE(graph != nullptr);
  179. uint32_t rank_size = 2;
  180. std::vector<int64_t> send_rank_ids = {1};
  181. std::vector<int64_t> recv_rank_ids = {1};
  182. auto alltoall = CreateAllToAllvNode(graph, CreateInputs(graph, {{1}}, {type_}), send_rank_ids, recv_rank_ids);
  183. ASSERT_TRUE(alltoall != nullptr);
  184. ASSERT_NO_THROW(SetOutputs(alltoall, {{1}}, {type_}));
  185. AllToAllvCalcParam calc(alltoall, rank_size);
  186. ASSERT_NO_THROW(calc.CalcOpParam());
  187. EXPECT_EQ(calc.GetSendCounts(), std::vector<int64_t>({0, 1}));
  188. EXPECT_EQ(calc.GetSendDispls(), std::vector<int64_t>({0, 0}));
  189. EXPECT_EQ(calc.GetRecvCounts(), std::vector<int64_t>({0, 1}));
  190. EXPECT_EQ(calc.GetRecvDispls(), std::vector<int64_t>({0, 0}));
  191. }
  192. /// Feature: AllToAllvCalcParam
  193. /// Description: on 2p, send to rank0, and recv from rank0
  194. /// Expectation: send count 1 0
  195. /// send offset 0 128
  196. /// recv count 1 0
  197. /// recv offset 0 128
  198. TEST_F(TestHcclAdapter, test_all_to_all_v_calc_param_2p_send_to_self) {
  199. auto graph = std::make_shared<FuncGraph>();
  200. ASSERT_TRUE(graph != nullptr);
  201. uint32_t rank_size = 2;
  202. std::vector<int64_t> send_rank_ids = {0};
  203. std::vector<int64_t> recv_rank_ids = {0};
  204. auto alltoall = CreateAllToAllvNode(graph, CreateInputs(graph, {{1}}, {type_}), send_rank_ids, recv_rank_ids);
  205. ASSERT_TRUE(alltoall != nullptr);
  206. ASSERT_NO_THROW(SetOutputs(alltoall, {{1}}, {type_}));
  207. AllToAllvCalcParam calc(alltoall, rank_size);
  208. ASSERT_NO_THROW(calc.CalcOpParam());
  209. EXPECT_EQ(calc.GetSendCounts(), std::vector<int64_t>({1, 0}));
  210. EXPECT_EQ(calc.GetSendDispls(), std::vector<int64_t>({0, 128}));
  211. EXPECT_EQ(calc.GetRecvCounts(), std::vector<int64_t>({1, 0}));
  212. EXPECT_EQ(calc.GetRecvDispls(), std::vector<int64_t>({0, 128}));
  213. }
  214. /// Feature: AllToAllvCalcParam
  215. /// Description: on 4p, send to rank0123, and recv from rank0123
  216. /// Expectation: send count 1 1 1 1
  217. /// send offset 0 128 256 384
  218. /// recv count 1 1 1 1
  219. /// recv offset 0 128 256 384
  220. TEST_F(TestHcclAdapter, test_all_to_all_v_calc_param_4p_all_to_all) {
  221. auto graph = std::make_shared<FuncGraph>();
  222. ASSERT_TRUE(graph != nullptr);
  223. uint32_t rank_size = 4;
  224. std::vector<int64_t> send_rank_ids = {0, 1, 2, 3};
  225. std::vector<int64_t> recv_rank_ids = {0, 1, 2, 3};
  226. auto alltoall = CreateAllToAllvNode(graph, CreateInputs(graph, {{1}, {1}, {1}, {1}}, {type_, type_, type_, type_}),
  227. send_rank_ids, recv_rank_ids);
  228. ASSERT_TRUE(alltoall != nullptr);
  229. ASSERT_NO_THROW(SetOutputs(alltoall, {{1}, {1}, {1}, {1}}, {type_, type_, type_, type_}));
  230. AllToAllvCalcParam calc(alltoall, rank_size);
  231. ASSERT_NO_THROW(calc.CalcOpParam());
  232. EXPECT_EQ(calc.GetSendCounts(), std::vector<int64_t>({1, 1, 1, 1}));
  233. EXPECT_EQ(calc.GetSendDispls(), std::vector<int64_t>({0, 128, 256, 384}));
  234. EXPECT_EQ(calc.GetRecvCounts(), std::vector<int64_t>({1, 1, 1, 1}));
  235. EXPECT_EQ(calc.GetRecvDispls(), std::vector<int64_t>({0, 128, 256, 384}));
  236. }
  237. /// Feature: AllToAllvCalcParam
  238. /// Description: on 4p, send to rank0123, and recv from rank0123, but recv order is wrong
  239. /// Expectation: send count 1 1 1 1
  240. /// send offset 0 128 256 384
  241. /// recv count 1 1 1 1
  242. /// recv offset 256 128 384 0
  243. TEST_F(TestHcclAdapter, test_all_to_all_v_calc_param_4p_all_in_all_in_wrong_order) {
  244. auto graph = std::make_shared<FuncGraph>();
  245. ASSERT_TRUE(graph != nullptr);
  246. uint32_t rank_size = 4;
  247. std::vector<int64_t> send_rank_ids = {0, 1, 2, 3};
  248. std::vector<int64_t> recv_rank_ids = {3, 1, 0, 2};
  249. auto alltoall = CreateAllToAllvNode(graph, CreateInputs(graph, {{1}, {1}, {1}, {1}}, {type_, type_, type_, type_}),
  250. send_rank_ids, recv_rank_ids);
  251. ASSERT_TRUE(alltoall != nullptr);
  252. ASSERT_NO_THROW(SetOutputs(alltoall, {{1}, {1}, {1}, {1}}, {type_, type_, type_, type_}));
  253. AllToAllvCalcParam calc(alltoall, rank_size);
  254. ASSERT_NO_THROW(calc.CalcOpParam());
  255. EXPECT_EQ(calc.GetSendCounts(), std::vector<int64_t>({1, 1, 1, 1}));
  256. EXPECT_EQ(calc.GetSendDispls(), std::vector<int64_t>({0, 128, 256, 384}));
  257. EXPECT_EQ(calc.GetRecvCounts(), std::vector<int64_t>({1, 1, 1, 1}));
  258. EXPECT_EQ(calc.GetRecvDispls(), std::vector<int64_t>({256, 128, 384, 0}));
  259. }
  260. /// Feature: AllToAllvCalcParam
  261. /// Description: on 4p, send to rank123, and recv from nothing, but send order is wrong
  262. /// Expectation: send count 0 1 1 1
  263. /// send offset 0 128 256 0
  264. /// recv count 0 0 0 0
  265. /// recv offset 0 0 0 0
  266. TEST_F(TestHcclAdapter, test_all_to_all_v_calc_param_4p_only_send_in_wrong_order) {
  267. auto graph = std::make_shared<FuncGraph>();
  268. ASSERT_TRUE(graph != nullptr);
  269. uint32_t rank_size = 4;
  270. std::vector<int64_t> send_rank_ids = {3, 1, 2};
  271. std::vector<int64_t> recv_rank_ids = {};
  272. auto alltoall = CreateAllToAllvNode(graph, CreateInputs(graph, {{1}, {1}, {1}}, {type_, type_, type_}), send_rank_ids,
  273. recv_rank_ids);
  274. ASSERT_TRUE(alltoall != nullptr);
  275. ASSERT_NO_THROW(SetOutputs(alltoall, {}, {}));
  276. AllToAllvCalcParam calc(alltoall, rank_size);
  277. ASSERT_NO_THROW(calc.CalcOpParam());
  278. EXPECT_EQ(calc.GetSendCounts(), std::vector<int64_t>({0, 1, 1, 1}));
  279. EXPECT_EQ(calc.GetSendDispls(), std::vector<int64_t>({0, 128, 256, 0}));
  280. EXPECT_EQ(calc.GetRecvCounts(), std::vector<int64_t>({0, 0, 0, 0}));
  281. EXPECT_EQ(calc.GetRecvDispls(), std::vector<int64_t>({0, 0, 0, 0}));
  282. }
  283. /// Feature: AllToAllvCalcParam
  284. /// Description: on 2p, rank id over valid range
  285. /// Expectation: throw exception
  286. TEST_F(TestHcclAdapter, test_all_to_all_v_calc_param_2p_invalid_rank_id) {
  287. auto graph = std::make_shared<FuncGraph>();
  288. ASSERT_TRUE(graph != nullptr);
  289. uint32_t rank_size = 2;
  290. std::vector<int64_t> send_rank_ids = {};
  291. std::vector<int64_t> recv_rank_ids = {0, 2};
  292. auto alltoall = CreateAllToAllvNode(graph, CreateInputs(graph, {}, {}), send_rank_ids, recv_rank_ids);
  293. ASSERT_TRUE(alltoall != nullptr);
  294. ASSERT_NO_THROW(SetOutputs(alltoall, {{1}, {1}}, {type_, type_}));
  295. AllToAllvCalcParam calc(alltoall, rank_size);
  296. ASSERT_ANY_THROW(calc.CalcOpParam());
  297. }
  298. /// Feature: AllToAllvCalcParam
  299. /// Description: on 2p, has 2 outputs but only 1 recv_rank_ids is set
  300. /// Expectation: throw exception
  301. TEST_F(TestHcclAdapter, test_all_to_all_v_calc_param_2p_invalid_rank_id_2) {
  302. auto graph = std::make_shared<FuncGraph>();
  303. ASSERT_TRUE(graph != nullptr);
  304. uint32_t rank_size = 2;
  305. std::vector<int64_t> send_rank_ids = {};
  306. std::vector<int64_t> recv_rank_ids = {0};
  307. auto alltoall = CreateAllToAllvNode(graph, CreateInputs(graph, {}, {}), send_rank_ids, recv_rank_ids);
  308. ASSERT_TRUE(alltoall != nullptr);
  309. ASSERT_NO_THROW(SetOutputs(alltoall, {{1}, {1}}, {type_, type_}));
  310. AllToAllvCalcParam calc(alltoall, rank_size);
  311. ASSERT_ANY_THROW(calc.CalcOpParam());
  312. }
  313. /// Feature: AllToAllvCalcParam
  314. /// Description: on 2p, rank id over valid range
  315. /// Expectation: throw exception
  316. TEST_F(TestHcclAdapter, test_all_to_all_v_calc_param_2p_wrong_order_and_invalid_rank_id) {
  317. auto graph = std::make_shared<FuncGraph>();
  318. ASSERT_TRUE(graph != nullptr);
  319. uint32_t rank_size = 2;
  320. std::vector<int64_t> send_rank_ids = {};
  321. std::vector<int64_t> recv_rank_ids = {2, 0};
  322. auto alltoall = CreateAllToAllvNode(graph, CreateInputs(graph, {}, {}), send_rank_ids, recv_rank_ids);
  323. ASSERT_TRUE(alltoall != nullptr);
  324. ASSERT_NO_THROW(SetOutputs(alltoall, {{1}, {1}}, {type_, type_}));
  325. AllToAllvCalcParam calc(alltoall, rank_size);
  326. ASSERT_ANY_THROW(calc.CalcOpParam());
  327. }
  328. } // namespace mindspore::hccl