You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

hccl_adapter_test.cc 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <memory>
  17. #include "common/common_test.h"
  18. #include "plugin/device/ascend/hal/hccl_adapter/all_to_all_v_calc_param.h"
  19. #include "backend/common/session/anf_runtime_algorithm.h"
  20. #include "include/common/utils/anfalgo.h"
  21. #include "mindspore/core/ir/dtype/type_id.h"
  22. namespace mindspore::hccl {
  23. class TestHcclAdapter : public UT::Common {
  24. public:
  25. TestHcclAdapter() {}
  26. protected:
  27. CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> inputs,
  28. const std::vector<int64_t> &send_rank_ids, const std::vector<int64_t> &recv_rank_ids) {
  29. MS_EXCEPTION_IF_NULL(graph);
  30. std::vector<AnfNodePtr> all_to_all_v_input = {NewValueNode(std::make_shared<Primitive>(kAllToAllVOpName))};
  31. all_to_all_v_input.insert(all_to_all_v_input.end(), inputs.begin(), inputs.end());
  32. auto all_to_all_v = graph->NewCNode(all_to_all_v_input);
  33. MS_EXCEPTION_IF_NULL(all_to_all_v);
  34. common::AnfAlgo::SetNodeAttr(kAttrSendRankIds, MakeValue<std::vector<int64_t>>(send_rank_ids), all_to_all_v);
  35. common::AnfAlgo::SetNodeAttr(kAttrRecvRankIds, MakeValue<std::vector<int64_t>>(recv_rank_ids), all_to_all_v);
  36. common::AnfAlgo::SetNodeAttr(kAttrGroup, MakeValue<std::string>("default_group"), all_to_all_v);
  37. return all_to_all_v;
  38. }
  39. void SetOutputs(const CNodePtr &cnode, const std::vector<std::vector<size_t>> &shape,
  40. const std::vector<TypeId> &data_type) {
  41. common::AnfAlgo::SetOutputInferTypeAndShape(data_type, shape, cnode.get());
  42. kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  43. builder.SetFusionType(kernel::FusionType::OPAQUE);
  44. builder.SetProcessor(kernel::Processor::AICORE);
  45. builder.SetKernelType(TBE_KERNEL);
  46. builder.SetInputsFormat(std::vector<std::string>(cnode->size() - 1, format_));
  47. builder.SetOutputsFormat(std::vector<std::string>(shape.size(), format_));
  48. builder.SetInputsDeviceType(std::vector<TypeId>(cnode->size() - 1, type_));
  49. builder.SetOutputsDeviceType(std::vector<TypeId>(shape.size(), type_));
  50. cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
  51. AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
  52. }
  53. std::vector<AnfNodePtr> CreateInputs(const FuncGraphPtr &graph, const std::vector<std::vector<size_t>> &shape,
  54. const std::vector<TypeId> &data_type) {
  55. MS_EXCEPTION_IF_NULL(graph);
  56. if (shape.size() != data_type.size()) {
  57. return {};
  58. }
  59. std::vector<AnfNodePtr> res;
  60. for (size_t i = 0; i < shape.size(); ++i) {
  61. auto node = graph->NewCNode(std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>("AnyNameOp"))});
  62. common::AnfAlgo::SetOutputInferTypeAndShape(std::vector<TypeId>{data_type[i]},
  63. std::vector<std::vector<size_t>>{shape[i]}, node.get());
  64. kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  65. builder.SetFusionType(kernel::FusionType::OPAQUE);
  66. builder.SetProcessor(kernel::Processor::AICORE);
  67. builder.SetKernelType(TBE_KERNEL);
  68. builder.SetInputsFormat({format_});
  69. builder.SetOutputsFormat({format_});
  70. builder.SetInputsDeviceType({type_});
  71. builder.SetOutputsDeviceType({type_});
  72. node->set_kernel_info(std::make_shared<device::KernelInfo>());
  73. AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), node.get());
  74. res.emplace_back(node);
  75. }
  76. return res;
  77. }
  78. TypeId type_ = TypeId::kNumberTypeInt32;
  79. std::string format_ = "NCHW";
  80. };
  81. /// Feature: AllToAllvCalcParam
  82. /// Description: on 2p, send to rank 1, and recv nothing
  83. /// Expectation: send count 0 1
  84. /// send offset 0 0
  85. /// recv count 0 0
  86. /// recv offset 0 0
  87. TEST_F(TestHcclAdapter, test_all_to_all_v_calc_param_2p_only_send) {
  88. auto graph = std::make_shared<FuncGraph>();
  89. ASSERT_TRUE(graph != nullptr);
  90. uint32_t rank_size = 2;
  91. std::vector<int64_t> send_rank_ids = {1};
  92. std::vector<int64_t> recv_rank_ids = {};
  93. auto alltoall = CreateAllToAllvNode(graph, CreateInputs(graph, {{1}}, {type_}), send_rank_ids, recv_rank_ids);
  94. ASSERT_TRUE(alltoall != nullptr);
  95. ASSERT_NO_THROW(SetOutputs(alltoall, {}, {}));
  96. AllToAllvCalcParam calc(alltoall, rank_size);
  97. ASSERT_NO_THROW(calc.CalcOpParam());
  98. EXPECT_EQ(calc.GetSendCounts(), std::vector<int64_t>({0, 1}));
  99. EXPECT_EQ(calc.GetSendDispls(), std::vector<int64_t>({0, 0}));
  100. EXPECT_EQ(calc.GetRecvCounts(), std::vector<int64_t>({0, 0}));
  101. EXPECT_EQ(calc.GetRecvDispls(), std::vector<int64_t>({0, 0}));
  102. }
  103. /// Feature: AllToAllvCalcParam
  104. /// Description: on 2p, send nothing, and recv from rank 0 and rank 1
  105. /// Expectation: send count 0 0
  106. /// send offset 0 0
  107. /// recv count 1 1
  108. /// recv offset 0 128
  109. TEST_F(TestHcclAdapter, test_all_to_all_v_calc_param_2p_only_recv) {
  110. auto graph = std::make_shared<FuncGraph>();
  111. ASSERT_TRUE(graph != nullptr);
  112. uint32_t rank_size = 2;
  113. std::vector<int64_t> send_rank_ids = {};
  114. std::vector<int64_t> recv_rank_ids = {0, 1};
  115. auto alltoall = CreateAllToAllvNode(graph, CreateInputs(graph, {}, {}), send_rank_ids, recv_rank_ids);
  116. ASSERT_TRUE(alltoall != nullptr);
  117. ASSERT_NO_THROW(SetOutputs(alltoall, {{1}, {1}}, {type_, type_}));
  118. AllToAllvCalcParam calc(alltoall, rank_size);
  119. ASSERT_NO_THROW(calc.CalcOpParam());
  120. EXPECT_EQ(calc.GetSendCounts(), std::vector<int64_t>({0, 0}));
  121. EXPECT_EQ(calc.GetSendDispls(), std::vector<int64_t>({0, 0}));
  122. EXPECT_EQ(calc.GetRecvCounts(), std::vector<int64_t>({1, 1}));
  123. EXPECT_EQ(calc.GetRecvDispls(), std::vector<int64_t>({0, 128}));
  124. }
  125. /// Feature: AllToAllvCalcParam
  126. /// Description: on 4p, send to rank1,2,3, and recv nothing
  127. /// Expectation: send count 0 1 1 1
  128. /// send offset 0 0 128 256
  129. /// recv count 0 0 0 0
  130. /// recv offset 0 0 0 0
  131. TEST_F(TestHcclAdapter, test_all_to_all_v_calc_param_4p_only_send) {
  132. auto graph = std::make_shared<FuncGraph>();
  133. ASSERT_TRUE(graph != nullptr);
  134. uint32_t rank_size = 4;
  135. std::vector<int64_t> send_rank_ids = {1, 2, 3};
  136. std::vector<int64_t> recv_rank_ids = {};
  137. auto alltoall = CreateAllToAllvNode(graph, CreateInputs(graph, {{1}, {1}, {1}}, {type_, type_, type_}), send_rank_ids,
  138. recv_rank_ids);
  139. ASSERT_TRUE(alltoall != nullptr);
  140. ASSERT_NO_THROW(SetOutputs(alltoall, {}, {}));
  141. AllToAllvCalcParam calc(alltoall, rank_size);
  142. ASSERT_NO_THROW(calc.CalcOpParam());
  143. EXPECT_EQ(calc.GetSendCounts(), std::vector<int64_t>({0, 1, 1, 1}));
  144. EXPECT_EQ(calc.GetSendDispls(), std::vector<int64_t>({0, 0, 128, 256}));
  145. EXPECT_EQ(calc.GetRecvCounts(), std::vector<int64_t>({0, 0, 0, 0}));
  146. EXPECT_EQ(calc.GetRecvDispls(), std::vector<int64_t>({0, 0, 0, 0}));
  147. }
  148. /// Feature: AllToAllvCalcParam
  149. /// Description: on 4p, send to rank1,3, and recv nothing
  150. /// Expectation: send count 0 1 0 1
  151. /// send offset 0 0 128 128
  152. /// recv count 0 0 0 0
  153. /// recv offset 0 0 0 0
  154. TEST_F(TestHcclAdapter, test_all_to_all_v_calc_param_4p_only_send_2) {
  155. auto graph = std::make_shared<FuncGraph>();
  156. ASSERT_TRUE(graph != nullptr);
  157. uint32_t rank_size = 4;
  158. std::vector<int64_t> send_rank_ids = {1, 3};
  159. std::vector<int64_t> recv_rank_ids = {};
  160. auto alltoall =
  161. CreateAllToAllvNode(graph, CreateInputs(graph, {{1}, {1}}, {type_, type_}), send_rank_ids, recv_rank_ids);
  162. ASSERT_TRUE(alltoall != nullptr);
  163. ASSERT_NO_THROW(SetOutputs(alltoall, {}, {}));
  164. AllToAllvCalcParam calc(alltoall, rank_size);
  165. ASSERT_NO_THROW(calc.CalcOpParam());
  166. EXPECT_EQ(calc.GetSendCounts(), std::vector<int64_t>({0, 1, 0, 1}));
  167. EXPECT_EQ(calc.GetSendDispls(), std::vector<int64_t>({0, 0, 128, 128}));
  168. EXPECT_EQ(calc.GetRecvCounts(), std::vector<int64_t>({0, 0, 0, 0}));
  169. EXPECT_EQ(calc.GetRecvDispls(), std::vector<int64_t>({0, 0, 0, 0}));
  170. }
  171. /// Feature: AllToAllvCalcParam
  172. /// Description: on 2p, send to rank1, and recv from rank1
  173. /// Expectation: send count 0 1
  174. /// send offset 0 0
  175. /// recv count 0 1
  176. /// recv offset 0 0
  177. TEST_F(TestHcclAdapter, test_all_to_all_v_calc_param_2p_exchange) {
  178. auto graph = std::make_shared<FuncGraph>();
  179. ASSERT_TRUE(graph != nullptr);
  180. uint32_t rank_size = 2;
  181. std::vector<int64_t> send_rank_ids = {1};
  182. std::vector<int64_t> recv_rank_ids = {1};
  183. auto alltoall = CreateAllToAllvNode(graph, CreateInputs(graph, {{1}}, {type_}), send_rank_ids, recv_rank_ids);
  184. ASSERT_TRUE(alltoall != nullptr);
  185. ASSERT_NO_THROW(SetOutputs(alltoall, {{1}}, {type_}));
  186. AllToAllvCalcParam calc(alltoall, rank_size);
  187. ASSERT_NO_THROW(calc.CalcOpParam());
  188. EXPECT_EQ(calc.GetSendCounts(), std::vector<int64_t>({0, 1}));
  189. EXPECT_EQ(calc.GetSendDispls(), std::vector<int64_t>({0, 0}));
  190. EXPECT_EQ(calc.GetRecvCounts(), std::vector<int64_t>({0, 1}));
  191. EXPECT_EQ(calc.GetRecvDispls(), std::vector<int64_t>({0, 0}));
  192. }
  193. /// Feature: AllToAllvCalcParam
  194. /// Description: on 2p, send to rank0, and recv from rank0
  195. /// Expectation: send count 1 0
  196. /// send offset 0 128
  197. /// recv count 1 0
  198. /// recv offset 0 128
  199. TEST_F(TestHcclAdapter, test_all_to_all_v_calc_param_2p_send_to_self) {
  200. auto graph = std::make_shared<FuncGraph>();
  201. ASSERT_TRUE(graph != nullptr);
  202. uint32_t rank_size = 2;
  203. std::vector<int64_t> send_rank_ids = {0};
  204. std::vector<int64_t> recv_rank_ids = {0};
  205. auto alltoall = CreateAllToAllvNode(graph, CreateInputs(graph, {{1}}, {type_}), send_rank_ids, recv_rank_ids);
  206. ASSERT_TRUE(alltoall != nullptr);
  207. ASSERT_NO_THROW(SetOutputs(alltoall, {{1}}, {type_}));
  208. AllToAllvCalcParam calc(alltoall, rank_size);
  209. ASSERT_NO_THROW(calc.CalcOpParam());
  210. EXPECT_EQ(calc.GetSendCounts(), std::vector<int64_t>({1, 0}));
  211. EXPECT_EQ(calc.GetSendDispls(), std::vector<int64_t>({0, 128}));
  212. EXPECT_EQ(calc.GetRecvCounts(), std::vector<int64_t>({1, 0}));
  213. EXPECT_EQ(calc.GetRecvDispls(), std::vector<int64_t>({0, 128}));
  214. }
  215. /// Feature: AllToAllvCalcParam
  216. /// Description: on 4p, send to rank0123, and recv from rank0123
  217. /// Expectation: send count 1 1 1 1
  218. /// send offset 0 128 256 384
  219. /// recv count 1 1 1 1
  220. /// recv offset 0 128 256 384
  221. TEST_F(TestHcclAdapter, test_all_to_all_v_calc_param_4p_all_to_all) {
  222. auto graph = std::make_shared<FuncGraph>();
  223. ASSERT_TRUE(graph != nullptr);
  224. uint32_t rank_size = 4;
  225. std::vector<int64_t> send_rank_ids = {0, 1, 2, 3};
  226. std::vector<int64_t> recv_rank_ids = {0, 1, 2, 3};
  227. auto alltoall = CreateAllToAllvNode(graph, CreateInputs(graph, {{1}, {1}, {1}, {1}}, {type_, type_, type_, type_}),
  228. send_rank_ids, recv_rank_ids);
  229. ASSERT_TRUE(alltoall != nullptr);
  230. ASSERT_NO_THROW(SetOutputs(alltoall, {{1}, {1}, {1}, {1}}, {type_, type_, type_, type_}));
  231. AllToAllvCalcParam calc(alltoall, rank_size);
  232. ASSERT_NO_THROW(calc.CalcOpParam());
  233. EXPECT_EQ(calc.GetSendCounts(), std::vector<int64_t>({1, 1, 1, 1}));
  234. EXPECT_EQ(calc.GetSendDispls(), std::vector<int64_t>({0, 128, 256, 384}));
  235. EXPECT_EQ(calc.GetRecvCounts(), std::vector<int64_t>({1, 1, 1, 1}));
  236. EXPECT_EQ(calc.GetRecvDispls(), std::vector<int64_t>({0, 128, 256, 384}));
  237. }
  238. /// Feature: AllToAllvCalcParam
  239. /// Description: on 4p, send to rank0123, and recv from rank0123, but recv order is wrong
  240. /// Expectation: send count 1 1 1 1
  241. /// send offset 0 128 256 384
  242. /// recv count 1 1 1 1
  243. /// recv offset 256 128 384 0
  244. TEST_F(TestHcclAdapter, test_all_to_all_v_calc_param_4p_all_in_all_in_wrong_order) {
  245. auto graph = std::make_shared<FuncGraph>();
  246. ASSERT_TRUE(graph != nullptr);
  247. uint32_t rank_size = 4;
  248. std::vector<int64_t> send_rank_ids = {0, 1, 2, 3};
  249. std::vector<int64_t> recv_rank_ids = {3, 1, 0, 2};
  250. auto alltoall = CreateAllToAllvNode(graph, CreateInputs(graph, {{1}, {1}, {1}, {1}}, {type_, type_, type_, type_}),
  251. send_rank_ids, recv_rank_ids);
  252. ASSERT_TRUE(alltoall != nullptr);
  253. ASSERT_NO_THROW(SetOutputs(alltoall, {{1}, {1}, {1}, {1}}, {type_, type_, type_, type_}));
  254. AllToAllvCalcParam calc(alltoall, rank_size);
  255. ASSERT_NO_THROW(calc.CalcOpParam());
  256. EXPECT_EQ(calc.GetSendCounts(), std::vector<int64_t>({1, 1, 1, 1}));
  257. EXPECT_EQ(calc.GetSendDispls(), std::vector<int64_t>({0, 128, 256, 384}));
  258. EXPECT_EQ(calc.GetRecvCounts(), std::vector<int64_t>({1, 1, 1, 1}));
  259. EXPECT_EQ(calc.GetRecvDispls(), std::vector<int64_t>({256, 128, 384, 0}));
  260. }
  261. /// Feature: AllToAllvCalcParam
  262. /// Description: on 4p, send to rank123, and recv from nothing, but send order is wrong
  263. /// Expectation: send count 0 1 1 1
  264. /// send offset 0 128 256 0
  265. /// recv count 0 0 0 0
  266. /// recv offset 0 0 0 0
  267. TEST_F(TestHcclAdapter, test_all_to_all_v_calc_param_4p_only_send_in_wrong_order) {
  268. auto graph = std::make_shared<FuncGraph>();
  269. ASSERT_TRUE(graph != nullptr);
  270. uint32_t rank_size = 4;
  271. std::vector<int64_t> send_rank_ids = {3, 1, 2};
  272. std::vector<int64_t> recv_rank_ids = {};
  273. auto alltoall = CreateAllToAllvNode(graph, CreateInputs(graph, {{1}, {1}, {1}}, {type_, type_, type_}), send_rank_ids,
  274. recv_rank_ids);
  275. ASSERT_TRUE(alltoall != nullptr);
  276. ASSERT_NO_THROW(SetOutputs(alltoall, {}, {}));
  277. AllToAllvCalcParam calc(alltoall, rank_size);
  278. ASSERT_NO_THROW(calc.CalcOpParam());
  279. EXPECT_EQ(calc.GetSendCounts(), std::vector<int64_t>({0, 1, 1, 1}));
  280. EXPECT_EQ(calc.GetSendDispls(), std::vector<int64_t>({0, 128, 256, 0}));
  281. EXPECT_EQ(calc.GetRecvCounts(), std::vector<int64_t>({0, 0, 0, 0}));
  282. EXPECT_EQ(calc.GetRecvDispls(), std::vector<int64_t>({0, 0, 0, 0}));
  283. }
  284. /// Feature: AllToAllvCalcParam
  285. /// Description: on 2p, rank id over valid range
  286. /// Expectation: throw exception
  287. TEST_F(TestHcclAdapter, test_all_to_all_v_calc_param_2p_invalid_rank_id) {
  288. auto graph = std::make_shared<FuncGraph>();
  289. ASSERT_TRUE(graph != nullptr);
  290. uint32_t rank_size = 2;
  291. std::vector<int64_t> send_rank_ids = {};
  292. std::vector<int64_t> recv_rank_ids = {0, 2};
  293. auto alltoall = CreateAllToAllvNode(graph, CreateInputs(graph, {}, {}), send_rank_ids, recv_rank_ids);
  294. ASSERT_TRUE(alltoall != nullptr);
  295. ASSERT_NO_THROW(SetOutputs(alltoall, {{1}, {1}}, {type_, type_}));
  296. AllToAllvCalcParam calc(alltoall, rank_size);
  297. ASSERT_ANY_THROW(calc.CalcOpParam());
  298. }
  299. /// Feature: AllToAllvCalcParam
  300. /// Description: on 2p, has 2 outputs but only 1 recv_rank_ids is set
  301. /// Expectation: throw exception
  302. TEST_F(TestHcclAdapter, test_all_to_all_v_calc_param_2p_invalid_rank_id_2) {
  303. auto graph = std::make_shared<FuncGraph>();
  304. ASSERT_TRUE(graph != nullptr);
  305. uint32_t rank_size = 2;
  306. std::vector<int64_t> send_rank_ids = {};
  307. std::vector<int64_t> recv_rank_ids = {0};
  308. auto alltoall = CreateAllToAllvNode(graph, CreateInputs(graph, {}, {}), send_rank_ids, recv_rank_ids);
  309. ASSERT_TRUE(alltoall != nullptr);
  310. ASSERT_NO_THROW(SetOutputs(alltoall, {{1}, {1}}, {type_, type_}));
  311. AllToAllvCalcParam calc(alltoall, rank_size);
  312. ASSERT_ANY_THROW(calc.CalcOpParam());
  313. }
  314. /// Feature: AllToAllvCalcParam
  315. /// Description: on 2p, rank id over valid range
  316. /// Expectation: throw exception
  317. TEST_F(TestHcclAdapter, test_all_to_all_v_calc_param_2p_wrong_order_and_invalid_rank_id) {
  318. auto graph = std::make_shared<FuncGraph>();
  319. ASSERT_TRUE(graph != nullptr);
  320. uint32_t rank_size = 2;
  321. std::vector<int64_t> send_rank_ids = {};
  322. std::vector<int64_t> recv_rank_ids = {2, 0};
  323. auto alltoall = CreateAllToAllvNode(graph, CreateInputs(graph, {}, {}), send_rank_ids, recv_rank_ids);
  324. ASSERT_TRUE(alltoall != nullptr);
  325. ASSERT_NO_THROW(SetOutputs(alltoall, {{1}, {1}}, {type_, type_}));
  326. AllToAllvCalcParam calc(alltoall, rank_size);
  327. ASSERT_ANY_THROW(calc.CalcOpParam());
  328. }
  329. } // namespace mindspore::hccl