You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

comm_manager.cc 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "utils/comm_manager.h"
  17. #include "utils/convert_utils.h"
  18. #include "utils/ms_context.h"
  19. #include "frontend/parallel/context.h"
  20. #include "frontend/parallel/group_manager.h"
  21. #ifndef NO_DLIB
  22. #include "runtime/hccl_adapter/hccl_adapter.h"
  23. #include "hccl/hcom.h"
  24. #include "runtime/device/ascend/distribute/ascend_collective.h"
  25. #endif
  26. #if defined(ENABLE_GPU)
  27. #include "runtime/device/gpu/distribution/collective_init.h"
  28. using CollectiveInitializer = mindspore::device::gpu::CollectiveInitializer;
  29. using CreateCommGroupFunc = mindspore::device::gpu::CreateCommGroupFunc;
  30. using GetRankIDByGroupFunc = mindspore::device::gpu::GetRankIDByGroupFunc;
  31. using GetGroupSizeFunc = mindspore::device::gpu::GetGroupSizeFunc;
  32. using DestroyGroupFunc = mindspore::device::gpu::DestroyGroupFunc;
  33. #endif
  34. namespace mindspore {
  35. #ifndef NO_DLIB
  36. CommManager &CommManager::GetInstance() noexcept {
  37. static CommManager instance("hccl");
  38. return instance;
  39. }
  // Evaluates `op` once. If the resulting status is non-zero, logs
  // "<op_name> failed: #<group>#" and makes the *enclosing* function return
  // false, so it may only be used inside functions returning bool.
  // Arguments are parenthesized so that arbitrary expressions expand safely.
  #define HCCL_RUN_CHECK(op_name, group, op) \
    do { \
      auto hccl_result = (op); \
      if (hccl_result != 0) { \
        MS_LOG(ERROR) << (op_name) << " failed: #" << (group) << "#"; \
        return false; \
      } \
    } while (0)
  // Makes the enclosing bool function return false (with an error log) when
  // the group name is empty.
  #define HCCL_GROUP_CHECK_EMPTY(group) \
    do { \
      if ((group).length() == 0) { \
        MS_LOG(ERROR) << "The length of group name should not be 0"; \
        return false; \
      } \
    } while (0)
  // Makes the enclosing bool function return false (with an error log) when
  // the group name is the reserved "hccl_world_group".
  #define HCCL_GROUP_CHECK_IS_WORLD(group) \
    do { \
      if ((group) == "hccl_world_group") { \
        MS_LOG(ERROR) << "The group name should not be hccl_world_group"; \
        return false; \
      } \
    } while (0)
  62. bool CommManager::CreateGroupSync(const string &group, const vector<unsigned int> &rank_id_list) const {
  63. auto rank_size = rank_id_list.size();
  64. HCCL_GROUP_CHECK_EMPTY(group);
  65. HCCL_GROUP_CHECK_IS_WORLD(group);
  66. auto context_ptr = MsContext::GetInstance();
  67. MS_EXCEPTION_IF_NULL(context_ptr);
  68. bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
  69. auto mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
  70. if (!is_task_sink && mode == kGraphMode) {
  71. HcclCollectiveGroup::instance().CreateCommGroup(group, rank_id_list);
  72. } else {
  73. HCCL_RUN_CHECK(string("create communicate group"), group,
  74. hccl::HcclAdapter::GetInstance().HcclCreateGroup(group, UlongToUint(rank_size),
  75. vector<unsigned int>(rank_id_list).data()));
  76. }
  77. return true;
  78. }
  79. bool CommManager::GetRankID(const string &group, unsigned int *rank_id) const {
  80. auto context = MsContext::GetInstance();
  81. MS_EXCEPTION_IF_NULL(context);
  82. if (context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
  83. HCCL_GROUP_CHECK_EMPTY(group);
  84. if (!context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
  85. *rank_id = static_cast<unsigned int>(HcclCollectiveGroup::instance().GetRankId(group));
  86. } else {
  87. HCCL_RUN_CHECK(string("get rank_id"), group, hccl::HcclAdapter::GetInstance().HcclGetRankId(group, rank_id));
  88. }
  89. } else {
  90. HCCL_RUN_CHECK(string("get rank_id"), group, hccl::HcclAdapter::GetInstance().HcclGetRankId(rank_id));
  91. }
  92. return true;
  93. }
  94. bool CommManager::GetRankSize(const string &group, unsigned int *rank_size) const {
  95. auto context = MsContext::GetInstance();
  96. MS_EXCEPTION_IF_NULL(context);
  97. if (context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
  98. HCCL_GROUP_CHECK_EMPTY(group);
  99. if (!context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
  100. *rank_size = static_cast<unsigned int>(HcclCollectiveGroup::instance().GetRankSize(group));
  101. } else {
  102. HCCL_RUN_CHECK(string("get rank size"), group,
  103. hccl::HcclAdapter::GetInstance().HcclGetRankSize(group, rank_size));
  104. }
  105. } else {
  106. HCCL_RUN_CHECK(string("get rank size"), group, hccl::HcclAdapter::GetInstance().HcclGetRankSize(rank_size));
  107. }
  108. return true;
  109. }
  110. bool CommManager::DestroyGroup(const string &group) const {
  111. HCCL_GROUP_CHECK_EMPTY(group);
  112. HCCL_GROUP_CHECK_IS_WORLD(group);
  113. HCCL_RUN_CHECK(string("destroy communicate group"), group, hccl::HcclAdapter::GetInstance().HcclDestroyGroup(group));
  114. return true;
  115. }
  116. #elif defined(ENABLE_GPU)
  117. CommManager &CommManager::GetInstance() noexcept {
  118. static CommManager instance("nccl");
  119. return instance;
  120. }
  121. bool CommManager::CreateGroupSync(const string &group, const vector<unsigned int> &rank_id_list) const {
  122. const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
  123. if (!collective_handle_) {
  124. MS_LOG(EXCEPTION) << "GPU collective handle is not initialized.";
  125. }
  126. MS_LOG(INFO) << "Create communication group " << group << " by rank id list " << rank_id_list;
  127. auto create_comm_group_funcptr =
  128. reinterpret_cast<CreateCommGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "CreateCommGroup"));
  129. MS_EXCEPTION_IF_NULL(create_comm_group_funcptr);
  130. bool ret = (*create_comm_group_funcptr)(group, rank_id_list);
  131. if (!ret) {
  132. MS_LOG(ERROR) << "Creating group " << group << "for rank id list" << rank_id_list << "failed.";
  133. return ret;
  134. }
  135. return ret;
  136. }
  137. bool CommManager::GetRankID(const string &group, unsigned int *rank_id) const {
  138. const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
  139. if (!collective_handle_) {
  140. MS_LOG(EXCEPTION) << "GPU collective handle is not initialized.";
  141. }
  142. auto get_rank_id_funcptr =
  143. reinterpret_cast<GetRankIDByGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "GetRankIDByGroup"));
  144. MS_EXCEPTION_IF_NULL(get_rank_id_funcptr);
  145. int rank = (*get_rank_id_funcptr)(group);
  146. *rank_id = static_cast<unsigned int>(rank);
  147. MS_LOG(INFO) << "This process rank id is " << *rank_id << " in group " << group;
  148. return true;
  149. }
  150. bool CommManager::GetRankSize(const string &group, unsigned int *rank_size) const {
  151. const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
  152. if (!collective_handle_) {
  153. MS_LOG(EXCEPTION) << "GPU collective handle is not initialized.";
  154. }
  155. auto get_group_size_funcptr =
  156. reinterpret_cast<GetGroupSizeFunc>(dlsym(const_cast<void *>(collective_handle_), "GetGroupSize"));
  157. MS_EXCEPTION_IF_NULL(get_group_size_funcptr);
  158. int size = (*get_group_size_funcptr)(group);
  159. *rank_size = static_cast<unsigned int>(size);
  160. MS_LOG(INFO) << "Group " << group << " size is " << *rank_size;
  161. return true;
  162. }
  163. bool CommManager::DestroyGroup(const string &group) const {
  164. const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
  165. if (!collective_handle_) {
  166. MS_LOG(EXCEPTION) << "GPU collective handle is not initialized.";
  167. }
  168. auto destroy_group_funcptr =
  169. reinterpret_cast<DestroyGroupFunc>(dlsym(const_cast<void *>(collective_handle_), "DestroyGroup"));
  170. MS_EXCEPTION_IF_NULL(destroy_group_funcptr);
  171. bool ret = (*destroy_group_funcptr)(group);
  172. if (!ret) {
  173. MS_LOG(ERROR) << "Destroying group " << group << " failed.";
  174. return ret;
  175. }
  176. return ret;
  177. }
  178. #else
  179. CommManager &CommManager::GetInstance() noexcept {
  180. static CommManager instance("hccl");
  181. return instance;
  182. }
  183. bool CommManager::CreateGroupSync(const string &, const vector<unsigned int> &) const { return true; }
  184. bool CommManager::GetRankID(const string &group, unsigned int *rank_id) const { return true; }
  185. bool CommManager::GetRankSize(const string &group, unsigned int *rank_size) const {
  186. *rank_size = NO_COMM_DLIB_RANK_SIZE;
  187. return true;
  188. }
  189. bool CommManager::DestroyGroup(const string &group) const { return true; }
  190. #endif
  191. uint32_t GetRank() {
  192. uint32_t rank_id = 0;
  193. auto ms_context = MsContext::GetInstance();
  194. MS_EXCEPTION_IF_NULL(ms_context);
  195. std::string world_group;
  196. std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  197. if (backend == kAscendDevice) {
  198. world_group = parallel::HCCL_WORLD_GROUP;
  199. } else if (backend == kGPUDevice) {
  200. world_group = parallel::NCCL_WORLD_GROUP;
  201. } else {
  202. // Other backends like CPU not support parallel, return rank_id with default 0.
  203. return rank_id;
  204. }
  205. auto parallel_context = parallel::ParallelContext::GetInstance();
  206. MS_EXCEPTION_IF_NULL(parallel_context);
  207. if (parallel_context->parallel_mode() != parallel::STAND_ALONE) {
  208. #ifndef NO_DLIB
  209. // Check HCCL inited.
  210. if (!hccl::HcclAdapter::GetInstance().Inited()) {
  211. MS_LOG(DEBUG) << "HCCL not inited, return rank_id = 0";
  212. return rank_id;
  213. }
  214. #elif defined(ENABLE_GPU)
  215. // Check NCCL inited.
  216. if (!CollectiveInitializer::instance().collective_inited()) {
  217. MS_LOG(DEBUG) << "NCLL not inited, return rank_id = 0";
  218. return rank_id;
  219. }
  220. #endif
  221. if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) {
  222. MS_LOG(EXCEPTION) << "Get rank id failed.";
  223. }
  224. }
  225. return rank_id;
  226. }
  227. bool IsStandAlone() {
  228. auto parallel_context = parallel::ParallelContext::GetInstance();
  229. MS_EXCEPTION_IF_NULL(parallel_context);
  230. return parallel_context->parallel_mode() == parallel::STAND_ALONE;
  231. }
  232. } // namespace mindspore