/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "utils/comm_manager.h"
#include "utils/convert_utils.h"
#include "utils/ms_context.h"
#include "frontend/parallel/context.h"
#include "frontend/parallel/group_manager.h"

#ifndef NO_DLIB
#include "runtime/hccl_adapter/hccl_adapter.h"
#include "hccl/hcom.h"
#include "runtime/device/ascend/distribute/ascend_collective.h"
#endif

#if defined(ENABLE_GPU)
#include "runtime/device/gpu/distribution/collective_init.h"
using CollectiveInitializer = mindspore::device::gpu::CollectiveInitializer;
using CreateCommGroupFunc = mindspore::device::gpu::CreateCommGroupFunc;
using GetRankIDByGroupFunc = mindspore::device::gpu::GetRankIDByGroupFunc;
using GetGroupSizeFunc = mindspore::device::gpu::GetGroupSizeFunc;
using DestroyGroupFunc = mindspore::device::gpu::DestroyGroupFunc;
#endif

namespace mindspore {
#ifndef NO_DLIB
// ---------------------------------------------------------------------------
// HCCL implementation (selected when NO_DLIB is not defined).
// ---------------------------------------------------------------------------

// Returns the process-wide CommManager, constructed with the name "hccl".
CommManager &CommManager::GetInstance() noexcept {
  static CommManager instance("hccl");
  return instance;
}

// Evaluates `op`; if the result is non-zero, logs "<op_name> failed: #<group>#" and makes the
// enclosing function return false. Only usable inside functions returning bool.
#define HCCL_RUN_CHECK(op_name, group, op)                      \
  do {                                                          \
    auto hccl_result = (op);                                    \
    if (hccl_result != 0) {                                     \
      MS_LOG(ERROR) << op_name << " failed: #" << group << "#"; \
      return false;                                             \
    }                                                           \
  } while (0)

// Makes the enclosing function return false (after logging) when the group name is empty.
#define HCCL_GROUP_CHECK_EMPTY(group)                              \
  do {                                                             \
    if (group.length() == 0) {                                     \
      MS_LOG(ERROR) << "The length of group name should not be 0"; \
      return false;                                                \
    }                                                              \
  } while (0)

// Makes the enclosing function return false (after logging) when the group name is
// "hccl_world_group".
#define HCCL_GROUP_CHECK_IS_WORLD(group)                                \
  do {                                                                  \
    if (group == "hccl_world_group") {                                  \
      MS_LOG(ERROR) << "The group name should not be hccl_world_group"; \
      return false;                                                     \
    }                                                                   \
  } while (0)

// Creates the communication group `group` containing the ranks in `rank_id_list`.
//
// The group name must be non-empty and must not be "hccl_world_group"; otherwise false is
// returned. In graph mode with task sink disabled the group is created through
// HcclCollectiveGroup (no status is checked on that path); in every other configuration it is
// created through HcclAdapter::HcclCreateGroup and a non-zero result yields false.
bool CommManager::CreateGroupSync(const string &group, const vector<unsigned int> &rank_id_list) const {
  auto rank_size = rank_id_list.size();
  HCCL_GROUP_CHECK_EMPTY(group);
  HCCL_GROUP_CHECK_IS_WORLD(group);
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
  auto mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
  if (!is_task_sink && mode == kGraphMode) {
    HcclCollectiveGroup::instance().CreateCommGroup(group, rank_id_list);
  } else {
    // The rank list is copied into a temporary so that a mutable data pointer can be handed to
    // the adapter; the temporary lives until the end of the full expression.
    HCCL_RUN_CHECK(string("create communicate group"), group,
                   hccl::HcclAdapter::GetInstance().HcclCreateGroup(group, UlongToUint(rank_size),
                                                                    vector<unsigned int>(rank_id_list).data()));
  }
  return true;
}

// Writes this process's rank id to `*rank_id`.
//
// In graph mode the group name must be non-empty and the rank is looked up within `group`
// (through HcclCollectiveGroup when task sink is disabled, through HcclAdapter otherwise).
// Outside graph mode `group` is only used for the error message and the adapter is queried
// without a group name. Returns false when a check or the adapter call fails.
bool CommManager::GetRankID(const string &group, unsigned int *rank_id) const {
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  if (context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
    HCCL_GROUP_CHECK_EMPTY(group);
    if (!context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
      *rank_id = static_cast<unsigned int>(HcclCollectiveGroup::instance().GetRankId(group));
    } else {
      HCCL_RUN_CHECK(string("get rank_id"), group, hccl::HcclAdapter::GetInstance().HcclGetRankId(group, rank_id));
    }
  } else {
    HCCL_RUN_CHECK(string("get rank_id"), group, hccl::HcclAdapter::GetInstance().HcclGetRankId(rank_id));
  }
  return true;
}

// Writes the number of ranks to `*rank_size`.
//
// Mirrors GetRankID: in graph mode the size of `group` is queried (through HcclCollectiveGroup
// when task sink is disabled, through HcclAdapter otherwise); outside graph mode the adapter is
// queried without a group name. Returns false when a check or the adapter call fails.
bool CommManager::GetRankSize(const string &group, unsigned int *rank_size) const {
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  if (context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
    HCCL_GROUP_CHECK_EMPTY(group);
    if (!context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
      *rank_size = static_cast<unsigned int>(HcclCollectiveGroup::instance().GetRankSize(group));
    } else {
      HCCL_RUN_CHECK(string("get rank size"), group,
                     hccl::HcclAdapter::GetInstance().HcclGetRankSize(group, rank_size));
    }
  } else {
    HCCL_RUN_CHECK(string("get rank size"), group, hccl::HcclAdapter::GetInstance().HcclGetRankSize(rank_size));
  }
  return true;
}

// Destroys the communication group `group` through HcclAdapter::HcclDestroyGroup.
// Returns false for an empty name, for "hccl_world_group", or when the adapter reports failure.
bool CommManager::DestroyGroup(const string &group) const {
  HCCL_GROUP_CHECK_EMPTY(group);
  HCCL_GROUP_CHECK_IS_WORLD(group);
  HCCL_RUN_CHECK(string("destroy communicate group"), group, hccl::HcclAdapter::GetInstance().HcclDestroyGroup(group));
  return true;
}
#elif defined(ENABLE_GPU)
// ---------------------------------------------------------------------------
// NCCL implementation (selected when NO_DLIB is defined and ENABLE_GPU is defined).
// ---------------------------------------------------------------------------

// Returns the process-wide CommManager, constructed with the name "nccl".
CommManager &CommManager::GetInstance() noexcept {
  static CommManager instance("nccl");
  return instance;
}

// Creates the communication group `group` for `rank_id_list` via CollectiveInitializer and
// returns its result, logging the outcome.
bool CommManager::CreateGroupSync(const string &group, const vector<unsigned int> &rank_id_list) const {
  bool ret = CollectiveInitializer::instance().CreateCommunicationGroup(group, rank_id_list);
  if (!ret) {
    MS_LOG(ERROR) << "Failed to create group " << group << " for rank id list " << rank_id_list;
    return ret;
  }
  MS_LOG(INFO) << "Successfully create group " << group << " for rank id list " << rank_id_list;
  return ret;
}

// Writes the rank id reported by CollectiveInitializer for `group` to `*rank_id`.
// Always returns true.
bool CommManager::GetRankID(const string &group, unsigned int *rank_id) const {
  *rank_id = CollectiveInitializer::instance().GetRankIDByGroup(group);
  MS_LOG(INFO) << "This process rank id is " << *rank_id << " in group " << group;
  return true;
}

// Writes the size reported by CollectiveInitializer for `group` to `*rank_size`.
// Always returns true.
bool CommManager::GetRankSize(const string &group, unsigned int *rank_size) const {
  *rank_size = CollectiveInitializer::instance().GetGroupSize(group);
  MS_LOG(INFO) << "Group " << group << " size is " << *rank_size;
  return true;
}

// Destroys the communication group `group` via CollectiveInitializer and returns its result,
// logging the outcome.
bool CommManager::DestroyGroup(const string &group) const {
  bool ret = CollectiveInitializer::instance().DestroyCommunicationGroup(group);
  if (!ret) {
    MS_LOG(ERROR) << "Failed to destroy group " << group;
    return ret;
  }
  MS_LOG(INFO) << "Successfully destroy group " << group;
  return ret;
}
#else
// ---------------------------------------------------------------------------
// Stub implementation (NO_DLIB defined, ENABLE_GPU not defined): every call succeeds.
// ---------------------------------------------------------------------------

// Returns the process-wide CommManager, constructed with the name "hccl".
CommManager &CommManager::GetInstance() noexcept {
  static CommManager instance("hccl");
  return instance;
}

// No-op; always reports success.
bool CommManager::CreateGroupSync(const string &, const vector<unsigned int> &) const { return true; }

// Always reports success. NOTE(review): `*rank_id` is left untouched here, so callers must
// initialize it themselves (GetRank() below initializes it to 0).
bool CommManager::GetRankID(const string &group, unsigned int *rank_id) const { return true; }

// Reports NO_COMM_DLIB_RANK_SIZE as the rank size; always returns true.
bool CommManager::GetRankSize(const string &group, unsigned int *rank_size) const {
  *rank_size = NO_COMM_DLIB_RANK_SIZE;
  return true;
}

// No-op; always reports success.
bool CommManager::DestroyGroup(const string &group) const { return true; }
#endif

// Returns the rank id of this process in the world group of the configured device target.
//
// Returns 0 when the target is neither Ascend nor GPU, when the parallel mode is STAND_ALONE,
// or when the collective library for this build (HCCL or NCCL) reports it is not initialized.
// Raises through MS_LOG(EXCEPTION) when CommManager::GetRankID fails.
uint32_t GetRank() {
  uint32_t rank_id = 0;
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  std::string world_group;
  std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  if (backend == kAscendDevice) {
    world_group = parallel::HCCL_WORLD_GROUP;
  } else if (backend == kGPUDevice) {
    world_group = parallel::NCCL_WORLD_GROUP;
  } else {
    // Other backends like CPU not support parallel, return rank_id with default 0.
    return rank_id;
  }
  auto parallel_context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(parallel_context);
  if (parallel_context->parallel_mode() != parallel::STAND_ALONE) {
#ifndef NO_DLIB
    // Check HCCL inited.
    if (!hccl::HcclAdapter::GetInstance().Inited()) {
      MS_LOG(DEBUG) << "HCCL not inited, return rank_id = 0";
      return rank_id;
    }
#elif defined(ENABLE_GPU)
    // Check NCCL inited.
    if (!CollectiveInitializer::instance().collective_inited()) {
      MS_LOG(DEBUG) << "NCCL not inited, return rank_id = 0";
      return rank_id;
    }
#endif
    if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) {
      MS_LOG(EXCEPTION) << "Get rank id failed.";
    }
  }
  return rank_id;
}

// Returns true when the parallel context's mode is STAND_ALONE.
bool IsStandAlone() {
  auto parallel_context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(parallel_context);
  return parallel_context->parallel_mode() == parallel::STAND_ALONE;
}
}  // namespace mindspore