You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

comm_manager.cc 7.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "utils/comm_manager.h"
  17. #include "utils/convert_utils.h"
  18. #ifndef NO_DLIB
  19. #include "runtime/hccl_adapter/hccl_adapter.h"
  20. #endif
  21. #if defined(ENABLE_GPU)
  22. #include "runtime/device/gpu/distribution/collective_init.h"
  23. using CollectiveInitializer = mindspore::device::gpu::CollectiveInitializer;
  24. using CreateCommGroupFunc = mindspore::device::gpu::CreateCommGroupFunc;
  25. using GetRankIDByGroupFunc = mindspore::device::gpu::GetRankIDByGroupFunc;
  26. using GetGroupSizeFunc = mindspore::device::gpu::GetGroupSizeFunc;
  27. using DestroyGroupFunc = mindspore::device::gpu::DestroyGroupFunc;
  28. #endif
  29. namespace mindspore {
  30. #ifndef NO_DLIB
  31. CommManager &CommManager::GetInstance() noexcept {
  32. static CommManager instance("hccl");
  33. return instance;
  34. }
  // Evaluates the HCCL call `op` once. On a non-zero status, logs "<op_name> failed: #<group>#" and makes
  // the *enclosing* function return false. Every argument is parenthesized so that arbitrary expressions
  // (e.g. conditionals) bind correctly.
  #define HCCL_RUN_CHECK(op_name, group, op)                          \
    do {                                                              \
      auto hccl_result = (op);                                        \
      if (hccl_result != 0) {                                         \
        MS_LOG(ERROR) << (op_name) << " failed: #" << (group) << "#"; \
        return false;                                                 \
      }                                                               \
    } while (0)

  // Makes the enclosing function log an error and return false when the group name `group` is empty.
  #define HCCL_GROUP_CHECK_EMPTY(group)                                  \
    do {                                                                 \
      if ((group).empty()) {                                             \
        MS_LOG(ERROR) << "The length of group name should not be 0";     \
        return false;                                                    \
      }                                                                  \
    } while (0)

  // Makes the enclosing function log an error and return false when `group` names the reserved
  // "hccl_world_group", which callers here must not create or destroy.
  #define HCCL_GROUP_CHECK_IS_WORLD(group)                                   \
    do {                                                                     \
      if ((group) == "hccl_world_group") {                                   \
        MS_LOG(ERROR) << "The group name should not be hccl_world_group";    \
        return false;                                                        \
      }                                                                      \
    } while (0)
  57. bool CommManager::CreateGroupSync(const string &group, const vector<unsigned int> &rank_id_list) const {
  58. auto rank_size = rank_id_list.size();
  59. HCCL_GROUP_CHECK_EMPTY(group);
  60. HCCL_GROUP_CHECK_IS_WORLD(group);
  61. HCCL_RUN_CHECK(string("create communicate group"), group,
  62. hccl::HcclAdapter::GetInstance().HcclCreateGroup(group, UlongToUint(rank_size),
  63. vector<unsigned int>(rank_id_list).data()));
  64. return true;
  65. }
  66. bool CommManager::GetRankID(const string &group, unsigned int *rank_id) const {
  67. HCCL_GROUP_CHECK_EMPTY(group);
  68. HCCL_RUN_CHECK(string("get rank_id"), group, hccl::HcclAdapter::GetInstance().HcclGetRankId(group, rank_id));
  69. return true;
  70. }
  71. bool CommManager::GetRankSize(const string &group, unsigned int *rank_size) const {
  72. HCCL_GROUP_CHECK_EMPTY(group);
  73. HCCL_RUN_CHECK(string("get rank size"), group, hccl::HcclAdapter::GetInstance().HcclGetRankSize(group, rank_size));
  74. return true;
  75. }
  76. bool CommManager::DestroyGroup(const string &group) const {
  77. HCCL_GROUP_CHECK_EMPTY(group);
  78. HCCL_GROUP_CHECK_IS_WORLD(group);
  79. HCCL_RUN_CHECK(string("destroy communicate group"), group, hccl::HcclAdapter::GetInstance().HcclDestroyGroup(group));
  80. return true;
  81. }
  82. #elif defined(ENABLE_GPU)
  83. CommManager &CommManager::GetInstance() noexcept {
  84. static CommManager instance("nccl");
  85. return instance;
  86. }
  87. bool CommManager::CreateGroupSync(const string &group, const vector<unsigned int> &rank_id_list) const {
  88. const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
  89. if (!collective_handle_) {
  90. MS_LOG(EXCEPTION) << "GPU collective handle is not initialized.";
  91. }
  92. MS_LOG(INFO) << "Create communication group " << group << " by rank id list " << rank_id_list;
  93. auto create_comm_group_funcptr =
  94. reinterpret_cast<CreateCommGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "CreateCommGroup"));
  95. MS_EXCEPTION_IF_NULL(create_comm_group_funcptr);
  96. bool ret = (*create_comm_group_funcptr)(group, rank_id_list);
  97. if (!ret) {
  98. MS_LOG(ERROR) << "Creating group " << group << "for rank id list" << rank_id_list << "failed.";
  99. return ret;
  100. }
  101. return ret;
  102. }
  103. bool CommManager::GetRankID(const string &group, unsigned int *rank_id) const {
  104. const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
  105. if (!collective_handle_) {
  106. MS_LOG(EXCEPTION) << "GPU collective handle is not initialized.";
  107. }
  108. auto get_rank_id_funcptr =
  109. reinterpret_cast<GetRankIDByGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "GetRankIDByGroup"));
  110. MS_EXCEPTION_IF_NULL(get_rank_id_funcptr);
  111. int rank = (*get_rank_id_funcptr)(group);
  112. *rank_id = static_cast<unsigned int>(rank);
  113. MS_LOG(INFO) << "This process rank id is " << *rank_id << " in group " << group;
  114. return true;
  115. }
  116. bool CommManager::GetRankSize(const string &group, unsigned int *rank_size) const {
  117. const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
  118. if (!collective_handle_) {
  119. MS_LOG(EXCEPTION) << "GPU collective handle is not initialized.";
  120. }
  121. auto get_group_size_funcptr =
  122. reinterpret_cast<GetGroupSizeFunc>(dlsym(const_cast<void *>(collective_handle_), "GetGroupSize"));
  123. MS_EXCEPTION_IF_NULL(get_group_size_funcptr);
  124. int size = (*get_group_size_funcptr)(group);
  125. *rank_size = static_cast<unsigned int>(size);
  126. MS_LOG(INFO) << "Group " << group << " size is " << *rank_size;
  127. return true;
  128. }
  129. bool CommManager::DestroyGroup(const string &group) const {
  130. const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
  131. if (!collective_handle_) {
  132. MS_LOG(EXCEPTION) << "GPU collective handle is not initialized.";
  133. }
  134. auto destroy_group_funcptr =
  135. reinterpret_cast<DestroyGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "DestroyGroup"));
  136. MS_EXCEPTION_IF_NULL(destroy_group_funcptr);
  137. bool ret = (*destroy_group_funcptr)(group);
  138. if (!ret) {
  139. MS_LOG(ERROR) << "Destroying group " << group << " failed.";
  140. return ret;
  141. }
  142. return ret;
  143. }
  144. #else
  145. CommManager &CommManager::GetInstance() noexcept {
  146. static CommManager instance("hccl");
  147. return instance;
  148. }
  149. bool CommManager::CreateGroupSync(const string &, const vector<unsigned int> &) const { return true; }
  150. bool CommManager::GetRankID(const string &group, unsigned int *rank_id) const { return true; }
  151. bool CommManager::GetRankSize(const string &group, unsigned int *rank_size) const {
  152. *rank_size = NO_COMM_DLIB_RANK_SIZE;
  153. return true;
  154. }
  155. bool CommManager::DestroyGroup(const string &group) const { return true; }
  156. #endif
  157. } // namespace mindspore