You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

alltoall_fusion.cc 8.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "backend/optimizer/gpu/alltoall_fusion.h"
  17. #include <vector>
  18. #include <string>
  19. #include "backend/session/anf_runtime_algorithm.h"
  20. #include "ir/primitive.h"
  21. #include "utils/utils.h"
  22. #include "runtime/device/gpu/kernel_info_setter.h"
  23. #include "backend/kernel_compiler/gpu/nccl/nccl_gpu_kernel.h"
  24. namespace mindspore {
  25. namespace opt {
  26. namespace {
  27. constexpr size_t kCNodePrimitiveIdx = 0;
  28. constexpr size_t kAllToAllInputIdx = 1;
  29. typedef std::vector<int> (*GetGroupRanks)(const std::string &);
  30. inline int64_t NormalizeDim(const std::vector<size_t> &shape, int64_t dim) {
  31. return dim < 0 ? SizeToLong(shape.size()) + dim : dim;
  32. }
  33. uint32_t GetRankSize(const std::string &group) {
  34. uint32_t rank_size;
  35. const void *collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle();
  36. MS_EXCEPTION_IF_NULL(collective_handle_);
  37. // Get group size
  38. auto get_group_size_funcptr =
  39. reinterpret_cast<GetGroupRanks>(dlsym(const_cast<void *>(collective_handle_), "GetGroupRanks"));
  40. MS_EXCEPTION_IF_NULL(get_group_size_funcptr);
  41. std::vector<int> group_ranks = (*get_group_size_funcptr)(group);
  42. rank_size = group_ranks.size();
  43. return rank_size;
  44. }
  45. CNodePtr CreateSplitNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all) {
  46. MS_EXCEPTION_IF_NULL(graph);
  47. MS_EXCEPTION_IF_NULL(all_to_all);
  48. int64_t split_count = AnfAlgo::GetNodeAttr<int64_t>(all_to_all, kAttrSplitCount);
  49. int64_t split_dim = AnfAlgo::GetNodeAttr<int64_t>(all_to_all, kAttrSplitDim);
  50. if (all_to_all->size() <= kAllToAllInputIdx) {
  51. MS_LOG(EXCEPTION) << "Invalid cnode " << all_to_all->DebugString() << " input size " << all_to_all->size();
  52. }
  53. // Make a split CNode.
  54. auto all_to_all_input = all_to_all->input(kAllToAllInputIdx);
  55. std::vector<AnfNodePtr> split_input = {NewValueNode(std::make_shared<Primitive>(prim::kPrimSplit->name())),
  56. all_to_all_input};
  57. auto split = graph->NewCNode(split_input);
  58. MS_EXCEPTION_IF_NULL(split);
  59. // Judge validity of split_dim and shape
  60. auto dtype = AnfAlgo::GetOutputInferDataType(all_to_all_input, 0);
  61. auto shape = AnfAlgo::GetOutputInferShape(all_to_all_input, 0);
  62. split_dim = NormalizeDim(shape, split_dim);
  63. if (SizeToLong(shape.size()) <= split_dim) {
  64. MS_LOG(EXCEPTION) << "Invalid split dim " << split_dim << " is over the shape size " << shape.size();
  65. }
  66. if (split_count == 0 || shape[LongToSize(split_dim)] % split_count != 0) {
  67. MS_LOG(EXCEPTION) << "Invalid split count " << split_count << " cannot be divisible by shape[" << split_dim
  68. << "] = " << shape[LongToSize(split_dim)];
  69. }
  70. shape[LongToSize(split_dim)] /= split_count;
  71. // Set Split CNode outputs type and shape, and CNode attributes.
  72. std::vector<TypeId> dtypes(split_count, dtype);
  73. std::vector<std::vector<size_t>> shapes(split_count, shape);
  74. AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, split.get());
  75. AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue<int64_t>(split_dim), split);
  76. AnfAlgo::SetNodeAttr(kAttrOutputNum, MakeValue<int64_t>(split_count), split);
  77. return split;
  78. }
  79. CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all, const CNodePtr &split) {
  80. MS_EXCEPTION_IF_NULL(graph);
  81. MS_EXCEPTION_IF_NULL(all_to_all);
  82. MS_EXCEPTION_IF_NULL(split);
  83. int64_t split_count = AnfAlgo::GetNodeAttr<int64_t>(all_to_all, kAttrSplitCount);
  84. std::string group = AnfAlgo::GetNodeAttr<std::string>(all_to_all, kAttrGroup);
  85. std::vector<AnfNodePtr> split_outputs;
  86. CreateMultipleOutputsOfAnfNode(graph, split, split_count, &split_outputs);
  87. if (split_outputs.empty()) {
  88. MS_LOG(EXCEPTION) << "The node " << split->DebugString() << " should have at least one output, but got 0.";
  89. }
  90. // Make a all_to_all_v CNode.
  91. std::vector<AnfNodePtr> all_to_all_v_input = {NewValueNode(std::make_shared<Primitive>(kAllToAllVOpName))};
  92. all_to_all_v_input.insert(all_to_all_v_input.end(), split_outputs.begin(), split_outputs.end());
  93. auto all_to_all_v = graph->NewCNode(all_to_all_v_input);
  94. MS_EXCEPTION_IF_NULL(all_to_all_v);
  95. // Prepare dtypes, shapes and ranks vectors.
  96. auto single_shape = AnfAlgo::GetOutputInferShape(split_outputs[0], 0);
  97. auto single_type = AnfAlgo::GetOutputInferDataType(split_outputs[0], 0);
  98. std::vector<TypeId> dtypes(split_count, single_type);
  99. std::vector<std::vector<size_t>> shapes(split_count, single_shape);
  100. AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, all_to_all_v.get());
  101. uint32_t rank_size = GetRankSize(group);
  102. std::vector<int64_t> rank_ids(rank_size, 0);
  103. for (uint32_t i = 0; i < rank_size; ++i) {
  104. rank_ids[i] = static_cast<int64_t>(i);
  105. }
  106. // Set AllToAllv CNode outputs and attributes.
  107. AnfAlgo::SetNodeAttr(kAttrSendRankIds, MakeValue<std::vector<int64_t>>(rank_ids), all_to_all_v);
  108. AnfAlgo::SetNodeAttr(kAttrRecvRankIds, MakeValue<std::vector<int64_t>>(rank_ids), all_to_all_v);
  109. AnfAlgo::SetNodeAttr(kAttrGroup, MakeValue<std::string>(group), all_to_all_v);
  110. MS_LOG(INFO) << "Create AllToAllv success, split count " << split_count << ", rank size " << rank_size;
  111. return all_to_all_v;
  112. }
  113. CNodePtr CreateConcatNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all, const CNodePtr &all_to_all_v) {
  114. MS_EXCEPTION_IF_NULL(graph);
  115. MS_EXCEPTION_IF_NULL(all_to_all);
  116. MS_EXCEPTION_IF_NULL(all_to_all_v);
  117. int64_t split_count = AnfAlgo::GetNodeAttr<int64_t>(all_to_all, kAttrSplitCount);
  118. int64_t concat_dim = AnfAlgo::GetNodeAttr<int64_t>(all_to_all, kAttrConcatDim);
  119. std::vector<AnfNodePtr> all_to_all_v_outputs;
  120. CreateMultipleOutputsOfAnfNode(graph, all_to_all_v, split_count, &all_to_all_v_outputs);
  121. if (all_to_all_v_outputs.empty()) {
  122. MS_LOG(EXCEPTION) << "The node " << all_to_all_v->DebugString() << " should have at least one output, but got 0.";
  123. }
  124. // Make a Concat CNode.
  125. std::vector<AnfNodePtr> concat_input = {NewValueNode(std::make_shared<Primitive>(kConcatOpName))};
  126. concat_input.insert(concat_input.end(), all_to_all_v_outputs.begin(), all_to_all_v_outputs.end());
  127. auto concat = graph->NewCNode(concat_input);
  128. MS_EXCEPTION_IF_NULL(concat);
  129. // Judge validity of concat_dim.
  130. auto single_shape = AnfAlgo::GetOutputInferShape(all_to_all_v_outputs[0], 0);
  131. concat_dim = NormalizeDim(single_shape, concat_dim);
  132. if (LongToSize(concat_dim) >= single_shape.size()) {
  133. MS_LOG(EXCEPTION) << "Invalid concat dim " << concat_dim << " is greater than shape size " << single_shape.size();
  134. }
  135. // Set Concat CNode outputs and attributes.
  136. single_shape[LongToSize(concat_dim)] *= split_count;
  137. AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(all_to_all_v_outputs[0], 0)}, {single_shape},
  138. concat.get());
  139. AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue<int64_t>(concat_dim), concat);
  140. AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(split_count), concat);
  141. std::vector<int64_t> dyn_input_size{split_count};
  142. AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_size), concat);
  143. return concat;
  144. }
  145. } // namespace
  146. const BaseRef AllToAllFusion::DefinePattern() const {
  147. return VectorRef({prim::kPrimAllToAll, std::make_shared<SeqVar>()});
  148. }
  149. const AnfNodePtr AllToAllFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
  150. MS_EXCEPTION_IF_NULL(graph);
  151. MS_EXCEPTION_IF_NULL(node);
  152. auto all_to_all = node->cast<CNodePtr>();
  153. MS_EXCEPTION_IF_NULL(all_to_all);
  154. // Step1: Split the AllToAll input Tensor into n_ranks parts along the AllToAll split_dim.
  155. auto split = CreateSplitNode(graph, all_to_all);
  156. // Step2: AllToAllv send and recv data to and from different rank.
  157. auto all_to_all_v = CreateAllToAllvNode(graph, all_to_all, split);
  158. // Step3: Concat all parts into one Tensor.
  159. auto concat = CreateConcatNode(graph, all_to_all, all_to_all_v);
  160. return concat;
  161. }
  162. } // namespace opt
  163. } // namespace mindspore