Merge pull request !3214 from ZPaC/enable-nccl-operation-by-group
@@ -279,6 +279,9 @@ if (ENABLE_GPU)
         ${CUDNN_PATH}/lib64/libcudnn.so
         ${CUDA_PATH}/lib64/libcudart.so
         ${CUDA_PATH}/lib64/stubs/libcuda.so)
+    if (ENABLE_MPI)
+        set_target_properties(_ms_mpi PROPERTIES INSTALL_RPATH ${ORIGIN_PATH})
+    endif()
 endif ()
 if (ENABLE_CPU)
@@ -99,5 +99,11 @@ MS_REG_GPU_KERNEL_TWO(
 MS_REG_GPU_KERNEL_TWO(
   Mul, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
   BroadcastOpGpuKernel, int, int)
+MS_REG_GPU_KERNEL_TWO(
+  RealDiv, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+  BroadcastOpGpuKernel, int, int)
+MS_REG_GPU_KERNEL_TWO(
+  FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+  BroadcastOpGpuKernel, int, int)
 } // namespace kernel
 } // namespace mindspore
@@ -96,9 +96,10 @@ class BroadcastOpGpuKernel : public GpuKernel {
   std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
   static std::map<std::string, BroadcastOpType> kBroadcastTypeMap = {
-    {"Greater", BROADCAST_TYPE_GREATER}, {"Less", BROADCAST_TYPE_LESS}, {"Maximum", BROADCAST_TYPE_MAXIMUM},
-    {"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER}, {"RealDiv", BROADCAST_TYPE_REALDIV},
-    {"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB}, {"TensorAdd", BROADCAST_TYPE_ADD},
+    {"Greater", BROADCAST_TYPE_GREATER}, {"Less", BROADCAST_TYPE_LESS}, {"Maximum", BROADCAST_TYPE_MAXIMUM},
+    {"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER}, {"RealDiv", BROADCAST_TYPE_REALDIV},
+    {"FloorDiv", BROADCAST_TYPE_REALDIV}, {"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB},
+    {"TensorAdd", BROADCAST_TYPE_ADD},
   };
   auto iter = kBroadcastTypeMap.find(kernel_name);
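For context, here is a standalone sketch of the string-to-enum dispatch this hunk extends; the map contents mirror the diff (note FloorDiv reuses the RealDiv kernel for the new int registrations), while the enum values, the miss handling, and main are illustrative assumptions rather than MindSpore source:

```cpp
#include <iostream>
#include <map>
#include <string>

// Hypothetical stand-in for the real BroadcastOpType enum.
enum BroadcastOpType { BROADCAST_TYPE_REALDIV, BROADCAST_TYPE_MUL, BROADCAST_TYPE_ADD };

int main() {
  static std::map<std::string, BroadcastOpType> kBroadcastTypeMap = {
      {"RealDiv", BROADCAST_TYPE_REALDIV},
      {"FloorDiv", BROADCAST_TYPE_REALDIV},  // FloorDiv is aliased onto the RealDiv kernel
      {"Mul", BROADCAST_TYPE_MUL},
      {"TensorAdd", BROADCAST_TYPE_ADD},
  };
  // Same lookup pattern as the diff: resolve the node name, fail on a miss.
  auto iter = kBroadcastTypeMap.find("FloorDiv");
  if (iter == kBroadcastTypeMap.end()) {
    std::cerr << "unsupported broadcast op" << std::endl;
    return 1;
  }
  std::cout << "dispatch enum value: " << iter->second << std::endl;
  return 0;
}
```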
@@ -24,17 +24,28 @@ MS_REG_GPU_KERNEL_ONE(
 MS_REG_GPU_KERNEL_ONE(
   AllReduce, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   NcclGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(AllReduce,
+                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      NcclGpuKernel, int)
 MS_REG_GPU_KERNEL_ONE(
   AllGather, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   NcclGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(
   AllGather, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   NcclGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(AllGather,
+                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      NcclGpuKernel, int)
 MS_REG_GPU_KERNEL_ONE(
   ReduceScatter, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   NcclGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(
   ReduceScatter, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   NcclGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(ReduceScatter,
+                      KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      NcclGpuKernel, int)
 } // namespace kernel
 } // namespace mindspore
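As background, a registration macro like the one above typically expands to a static object whose constructor adds a factory to a global registry keyed by op name. The sketch below is a generic illustration of that pattern, not MindSpore's actual MS_REG_GPU_KERNEL_ONE implementation; all names in it are made up:

```cpp
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

struct Kernel {
  virtual ~Kernel() = default;
  virtual void Launch() = 0;
};

using KernelFactory = std::function<std::unique_ptr<Kernel>()>;

// Function-local static avoids the static initialization order problem.
std::map<std::string, KernelFactory> &Registry() {
  static std::map<std::string, KernelFactory> registry;
  return registry;
}

struct Registrar {
  Registrar(const std::string &name, KernelFactory factory) { Registry()[name] = std::move(factory); }
};

// The macro creates one static Registrar per registration; its constructor
// runs before main and records the factory under the op's name.
#define REG_KERNEL(OP, KERNEL_TYPE) \
  static Registrar g_##OP##_reg(#OP, [] { return std::unique_ptr<Kernel>(new KERNEL_TYPE); })

struct AllReduceInt : Kernel {
  void Launch() override { std::cout << "launch AllReduce<int>\n"; }
};

REG_KERNEL(AllReduce, AllReduceInt);

int main() {
  Registry()["AllReduce"]()->Launch();  // look up the factory and run the kernel
  return 0;
}
```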
@@ -70,9 +70,7 @@ Status GroupManager::CreateGroup(const std::string &group_name, const std::vecto
                                  mindspore::parallel::Group *const group) {
   // it is simple to use size to determine whether it is a world group
   uint32_t world_size = 0;
-  if (world_group_ != NCCL_WORLD_GROUP) {
-    (void)CommManager::GetInstance().GetRankSize(world_group_, &world_size);
-  }
+  (void)CommManager::GetInstance().GetRankSize(world_group_, &world_size);
   if (devices.size() == world_size) {
     auto it = groups_.find(world_group_);
@@ -55,6 +55,7 @@ if (ENABLE_GPU)
         PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
     add_library(gpu_collective SHARED ${GPU_COLLECTIVE_SRCS})
     target_link_libraries(gpu_collective PRIVATE mindspore::ompi mindspore::nccl)
+    target_link_libraries(_ms_mpi PRIVATE gpu_collective)
 endif ()
 # add_library(_mindspore_device_cuda_obj OBJECT ${CUDA_SRC_LIST})
@@ -17,6 +17,7 @@
 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_COLLECTIVE_COMMON_H_
 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_COLLECTIVE_COMMON_H_
+#include <nccl.h>
 #include <sstream>
 #include "pybind11/pybind11.h"
@@ -25,6 +26,12 @@ namespace device {
 namespace gpu {
 constexpr int MAX_HOSTNAME_LEN = 1024;
 constexpr char NCCL_WORLD_GROUP[] = "nccl_world_group";
+struct NcclGroupInfo {
+  int size;
+  int rank;
+  ncclUniqueId unique_id;
+  ncclComm_t comm;
+};
 #define CHECK_RET(expression, result, message) \
   { \
     auto ret = (expression); \
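The CHECK_RET body is cut off by the hunk boundary above. A plausible completion, given that this header pulls in <sstream> for message building, compares the actual return value against the expected one and reports the message on mismatch; this is an assumption for illustration, not the verbatim MindSpore macro (the real one likely raises an error through pybind11 instead of printing):

```cpp
#include <iostream>
#include <sstream>

// Assumed completion: compare ret with the expected result, report on mismatch.
#define CHECK_RET(expression, result, message)                      \
  {                                                                 \
    auto ret = (expression);                                        \
    if (ret != (result)) {                                          \
      std::ostringstream oss;                                       \
      oss << "Error in file " << __FILE__ << " | Error on line "    \
          << __LINE__ << " : " << (message);                        \
      std::cerr << oss.str() << std::endl;                          \
    }                                                               \
  }

int main() {
  CHECK_RET(1 + 1, 2, "arithmetic check failed");  // passes silently
  CHECK_RET(1 + 1, 3, "this one reports");         // prints the message
  return 0;
}
```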
@@ -14,58 +14,37 @@
  * limitations under the License.
  */
-#include <mpi.h>
-#include <nccl.h>
-#include <unistd.h>
-#include <memory>
-#include <string>
-#include <iostream>
-#include <vector>
-#include "runtime/device/gpu/distribution/mpi_wrapper.h"
-#include "runtime/device/gpu/distribution/nccl_wrapper.h"
-#ifndef EXPORT_WRAPPER
-#define EXPORT_WRAPPER __attribute__((visibility("default")))
-#endif
-using MPIWrapper = mindspore::device::gpu::MPIWrapper;
-using NCCLWrapper = mindspore::device::gpu::NCCLWrapper;
-extern "C" EXPORT_WRAPPER void InitMPI() { MPIWrapper::instance(); }
-extern "C" EXPORT_WRAPPER int local_rank_id() { return MPIWrapper::instance().local_rank_id(); }
-extern "C" EXPORT_WRAPPER void InitNCCLComm() { NCCLWrapper::instance().InitNCCLComm(); }
-extern "C" EXPORT_WRAPPER bool CreateCommGroup(const std::string &group_name, const std::vector<unsigned int> &ranks) {
+#include "runtime/device/gpu/distribution/collective_wrapper.h"
+void InitMPI() { MPIWrapper::instance(); }
+int local_rank_id() { return MPIWrapper::instance().local_rank_id(); }
+void InitNCCLComm() { NCCLWrapper::instance().InitNCCLComm(); }
+bool CreateCommGroup(const std::string &group_name, const std::vector<unsigned int> &ranks) {
   return MPIWrapper::instance().CreateCommGroup(group_name, ranks);
 }
-extern "C" EXPORT_WRAPPER int GetRankIDByGroup(const std::string &group_name) {
-  return MPIWrapper::instance().GetRankIDByGroup(group_name);
-}
-extern "C" EXPORT_WRAPPER int GetGroupSize(const std::string &group_name) {
-  return MPIWrapper::instance().GetGroupSize(group_name);
-}
-extern "C" EXPORT_WRAPPER bool DestroyGroup(const std::string &group_name) {
-  return MPIWrapper::instance().DestroyGroup(group_name);
-}
+int GetRankIDByGroup(const std::string &group_name) { return MPIWrapper::instance().GetRankIDByGroup(group_name); }
+int GetGroupSize(const std::string &group_name) { return MPIWrapper::instance().GetGroupSize(group_name); }
+bool DestroyGroup(const std::string &group_name) { return MPIWrapper::instance().DestroyGroup(group_name); }
-extern "C" EXPORT_WRAPPER ncclResult_t AllReduce(const void *input_addr, void *output_addr, size_t count,
-                                                 ncclDataType_t data_type, ncclRedOp_t reduce_type,
-                                                 cudaStream_t stream) {
-  return NCCLWrapper::instance().AllReduce(input_addr, output_addr, count, data_type, reduce_type, stream);
+ncclResult_t AllReduce(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type,
+                       ncclRedOp_t reduce_type, cudaStream_t stream, const std::string &group) {
+  return NCCLWrapper::instance().AllReduce(input_addr, output_addr, count, data_type, reduce_type, stream, group);
 }
-extern "C" EXPORT_WRAPPER ncclResult_t AllGather(const void *input_addr, void *output_addr, size_t count,
-                                                 ncclDataType_t data_type, cudaStream_t stream) {
-  return NCCLWrapper::instance().AllGather(input_addr, output_addr, count, data_type, stream);
+ncclResult_t AllGather(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type,
+                       cudaStream_t stream, const std::string &group) {
+  return NCCLWrapper::instance().AllGather(input_addr, output_addr, count, data_type, stream, group);
 }
-extern "C" EXPORT_WRAPPER ncclResult_t ReduceScatter(const void *input_addr, void *output_addr, size_t count,
-                                                     ncclDataType_t data_type, ncclRedOp_t reduce_type,
-                                                     cudaStream_t stream) {
-  return NCCLWrapper::instance().ReduceScatter(input_addr, output_addr, count, data_type, reduce_type, stream);
+ncclResult_t ReduceScatter(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type,
+                           ncclRedOp_t reduce_type, cudaStream_t stream, const std::string &group) {
+  return NCCLWrapper::instance().ReduceScatter(input_addr, output_addr, count, data_type, reduce_type, stream, group);
 }
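The `extern "C" EXPORT_WRAPPER` declarations (now centralized in the new header below) matter because consumers can load this shared library at runtime and resolve the functions by their unmangled names; the old mpi_initializer.cc even included <dlfcn.h>. A minimal consumer sketch, where the library name, flags, and error handling are assumptions (link with -ldl):

```cpp
#include <dlfcn.h>
#include <iostream>

int main() {
  // Hypothetical library file name for the gpu_collective target built above.
  void *handle = dlopen("libgpu_collective.so", RTLD_LAZY | RTLD_GLOBAL);
  if (handle == nullptr) {
    std::cerr << "dlopen failed: " << dlerror() << std::endl;
    return 1;
  }
  // Without extern "C", the C++-mangled symbol would not match the plain
  // string "InitMPI"; default visibility keeps the symbol exported.
  using InitMPIFunc = void (*)();
  auto init_mpi = reinterpret_cast<InitMPIFunc>(dlsym(handle, "InitMPI"));
  if (init_mpi == nullptr) {
    std::cerr << "dlsym failed: " << dlerror() << std::endl;
    dlclose(handle);
    return 1;
  }
  init_mpi();
  dlclose(handle);
  return 0;
}
```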
@@ -0,0 +1,47 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <mpi.h>
+#include <nccl.h>
+#include <unistd.h>
+#include <string>
+#include <vector>
+#include "runtime/device/gpu/distribution/mpi_wrapper.h"
+#include "runtime/device/gpu/distribution/nccl_wrapper.h"
+#ifndef EXPORT_WRAPPER
+#define EXPORT_WRAPPER __attribute__((visibility("default")))
+#endif
+using MPIWrapper = mindspore::device::gpu::MPIWrapper;
+using NCCLWrapper = mindspore::device::gpu::NCCLWrapper;
+extern "C" EXPORT_WRAPPER void InitMPI();
+extern "C" EXPORT_WRAPPER int local_rank_id();
+extern "C" EXPORT_WRAPPER void InitNCCLComm();
+extern "C" EXPORT_WRAPPER bool CreateCommGroup(const std::string &group_name, const std::vector<unsigned int> &ranks);
+extern "C" EXPORT_WRAPPER int GetRankIDByGroup(const std::string &group_name);
+extern "C" EXPORT_WRAPPER int GetGroupSize(const std::string &group_name);
+extern "C" EXPORT_WRAPPER bool DestroyGroup(const std::string &group_name);
+extern "C" EXPORT_WRAPPER ncclResult_t AllReduce(const void *input_addr, void *output_addr, size_t count,
+                                                 ncclDataType_t data_type, ncclRedOp_t reduce_type, cudaStream_t stream,
+                                                 const std::string &group);
+extern "C" EXPORT_WRAPPER ncclResult_t AllGather(const void *input_addr, void *output_addr, size_t count,
+                                                 ncclDataType_t data_type, cudaStream_t stream,
+                                                 const std::string &group);
+extern "C" EXPORT_WRAPPER ncclResult_t ReduceScatter(const void *input_addr, void *output_addr, size_t count,
+                                                     ncclDataType_t data_type, ncclRedOp_t reduce_type,
+                                                     cudaStream_t stream, const std::string &group);
@@ -58,7 +58,7 @@ bool MPIWrapper::CreateCommGroup(const std::string &group_name, const std::vecto
   if (rank_id_ == ranks[0]) {
     group_unique_id = NCCLWrapper::instance().nccl_unique_id();
   }
-  MPI_Bcast(&group_unique_id, sizeof(ncclUniqueId), MPI_BYTE, ranks[0], mpi_group_comm);
+  MPI_Bcast(&group_unique_id, sizeof(ncclUniqueId), MPI_BYTE, 0, mpi_group_comm);
   int group_rank[1];
   int global_rank[1] = {rank_id_};
@@ -68,9 +68,8 @@ bool MPIWrapper::CreateCommGroup(const std::string &group_name, const std::vecto
     return false;
   }
-  ncclComm_t nccl_group_comm;
-  NCCLWrapper::instance().InitNCCLComm(&nccl_group_comm, ranks.size(), group_unique_id, group_rank[0]);
-  NCCLWrapper::instance().SetGroupNameToNCCLComm(group_name, nccl_group_comm);
+  NcclGroupInfo nccl_group = {static_cast<int>(ranks.size()), group_rank[0], group_unique_id, nullptr};
+  NCCLWrapper::instance().AddGroupInfo(group_name, &nccl_group);
   return true;
 }
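The MPI_Bcast fix two hunks above is worth spelling out: the root argument of MPI_Bcast is a rank within the communicator passed to it, and inside mpi_group_comm the leader (world rank ranks[0]) is rank 0, so passing the world rank was wrong whenever ranks[0] != 0. A minimal standalone MPI sketch of that semantics (not MindSpore code; group membership is a made-up example, run with e.g. mpirun -n 4):

```cpp
#include <mpi.h>
#include <algorithm>
#include <vector>

int main(int argc, char **argv) {
  MPI_Init(&argc, &argv);
  int world_rank = 0;
  MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);

  std::vector<int> ranks = {2, 3};  // world ranks forming the new group; leader is 2
  MPI_Group world_group = MPI_GROUP_NULL, sub_group = MPI_GROUP_NULL;
  MPI_Comm_group(MPI_COMM_WORLD, &world_group);
  MPI_Group_incl(world_group, static_cast<int>(ranks.size()), ranks.data(), &sub_group);

  // Only group members participate in creating the group communicator.
  if (std::find(ranks.begin(), ranks.end(), world_rank) != ranks.end()) {
    MPI_Comm group_comm = MPI_COMM_NULL;
    MPI_Comm_create_group(MPI_COMM_WORLD, sub_group, 0, &group_comm);
    int payload = (world_rank == ranks[0]) ? 42 : 0;
    // Correct root: 0, the leader's rank *inside* group_comm. Passing the
    // world rank ranks[0] (here 2) would address a nonexistent rank.
    MPI_Bcast(&payload, 1, MPI_INT, 0, group_comm);
    MPI_Comm_free(&group_comm);
  }
  MPI_Group_free(&sub_group);
  MPI_Group_free(&world_group);
  MPI_Finalize();
  return 0;
}
```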
@@ -111,7 +110,6 @@ void MPIWrapper::Init() {
   CHECK_RET(MPI_Comm_rank(MPI_COMM_WORLD, &rank_id_), MPI_SUCCESS, "Failed to init mpi rank id.");
   CHECK_RET(MPI_Comm_size(MPI_COMM_WORLD, &rank_size_), MPI_SUCCESS, "Failed to init mpi rank size.");
-  NCCLWrapper::instance().set_rank(rank_id_, rank_size_);
   AssignLocalRankID();
   CHECK_RET(MPI_Comm_group(MPI_COMM_WORLD, &world_group_), MPI_SUCCESS, "Failed to get group of MPI_COMM_WORLD");
@@ -123,7 +121,9 @@ void MPIWrapper::Init() {
   }
   CHECK_RET(MPI_Bcast(reinterpret_cast<void *>(&unique_id), sizeof(unique_id), MPI_BYTE, 0, MPI_COMM_WORLD),
             MPI_SUCCESS, "Failed to broadcast nccl unique id.");
-  NCCLWrapper::instance().set_nccl_unique_id(unique_id);
+  NcclGroupInfo world_group = {rank_size_, rank_id_, unique_id, nullptr};
+  NCCLWrapper::instance().AddGroupInfo(NCCL_WORLD_GROUP, &world_group);
   return;
 }
@@ -30,60 +30,58 @@ ncclUniqueId NCCLWrapper::nccl_unique_id() const {
   return unique_id;
 }
-void NCCLWrapper::set_nccl_unique_id(ncclUniqueId unique_id) { unique_id_ = unique_id; }
-void NCCLWrapper::set_rank(int rank_id, int rank_size) {
-  rank_id_ = rank_id;
-  rank_size_ = rank_size;
-}
 void NCCLWrapper::InitNCCLComm() {
-  CHECK_RET(ncclCommInitRank(&comm_, rank_size_, unique_id_, rank_id_), ncclSuccess,
-            "Failed to init nccl communicator.");
-  group_to_comm_map_[NCCL_WORLD_GROUP] = comm_;
-}
-void NCCLWrapper::InitNCCLComm(ncclComm_t *comm, int rank_size, ncclUniqueId unique_id, int rank) {
-  CHECK_RET(ncclCommInitRank(comm, rank_size, unique_id, rank), ncclSuccess, "Failed to init nccl communicator.");
+  for (auto group : group_info_) {
+    std::string group_name = group.first;
+    NcclGroupInfo group_info = group.second;
+    CHECK_RET(ncclCommInitRank(&(group_info.comm), group_info.size, group_info.unique_id, group_info.rank), ncclSuccess,
+              "Failed to init nccl communicator for group " + group_name);
+    group_info_[group_name].comm = group_info.comm;
+  }
+  comm_init_done_ = true;
 }
 ncclResult_t NCCLWrapper::AllReduce(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type,
                                     ncclRedOp_t reduce_type, cudaStream_t stream, const std::string &group_name) {
-  CHECK_RET(group_to_comm_map_.count(group_name), 1,
+  CHECK_RET(group_info_.count(group_name), 1,
             "Failed to find NCCL communicator for AllReduce by the group name " + group_name);
-  ncclComm_t group_comm = group_to_comm_map_[group_name];
+  ncclComm_t group_comm = group_info_[group_name].comm;
   return ncclAllReduce(input_addr, output_addr, count, data_type, reduce_type, group_comm, stream);
 }
 ncclResult_t NCCLWrapper::AllGather(const void *input_addr, void *output_addr, size_t count, ncclDataType_t data_type,
                                     cudaStream_t stream, const std::string &group_name) {
-  CHECK_RET(group_to_comm_map_.count(group_name), 1,
+  CHECK_RET(group_info_.count(group_name), 1,
            "Failed to find NCCL communicator for AllGather by the group name " + group_name);
-  ncclComm_t group_comm = group_to_comm_map_[group_name];
+  ncclComm_t group_comm = group_info_[group_name].comm;
   return ncclAllGather(input_addr, output_addr, count, data_type, group_comm, stream);
 }
 ncclResult_t NCCLWrapper::ReduceScatter(const void *input_addr, void *output_addr, size_t count,
                                         ncclDataType_t data_type, ncclRedOp_t reduce_type, cudaStream_t stream,
                                         const std::string &group_name) {
-  CHECK_RET(group_to_comm_map_.count(group_name), 1,
+  CHECK_RET(group_info_.count(group_name), 1,
            "Failed to find NCCL communicator for ReduceScatter by the group name " + group_name);
-  ncclComm_t group_comm = group_to_comm_map_[group_name];
+  ncclComm_t group_comm = group_info_[group_name].comm;
   return ncclReduceScatter(input_addr, output_addr, count, data_type, reduce_type, group_comm, stream);
 }
-void NCCLWrapper::SetGroupNameToNCCLComm(const std::string &group_name, const ncclComm_t comm) {
-  group_to_comm_map_[group_name] = comm;
+void NCCLWrapper::AddGroupInfo(const std::string &group_name, NcclGroupInfo *group) {
+  if (comm_init_done_) {
+    CHECK_RET(ncclCommInitRank(&(group->comm), group->size, group->unique_id, group->rank), ncclSuccess,
+              "Failed to init nccl communicator for group " + group_name);
+  }
+  group_info_[group_name] = *group;
 }
 void NCCLWrapper::DestroyGroup(const std::string &group_name) {
-  auto group_iter = group_to_comm_map_.find(group_name);
-  if (group_iter == group_to_comm_map_.end()) {
+  auto group_iter = group_info_.find(group_name);
+  if (group_iter == group_info_.end()) {
     return;
   }
-  group_to_comm_map_.erase(group_iter);
-  ncclComm_t group_comm = group_iter->second;
+  ncclComm_t group_comm = group_iter->second.comm;
   CHECK_RET(ncclCommDestroy(group_comm), ncclSuccess, "Failed to destroy NCCL communicator for " + group_name);
+  group_info_.erase(group_iter);
   return;
 }
 } // namespace gpu
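The refactor above introduces a deferred-initialization pattern: groups registered via AddGroupInfo before InitNCCLComm() runs are only recorded, the expensive ncclCommInitRank calls happen in one batch inside InitNCCLComm(), and any group added after that point is initialized immediately (the comm_init_done_ flag distinguishes the two paths). A minimal standalone sketch of that pattern, with made-up names and a fake init step standing in for ncclCommInitRank:

```cpp
#include <iostream>
#include <map>
#include <string>

struct GroupInfo {
  int size = 0;
  int rank = 0;
  bool initialized = false;  // stands in for the ncclComm_t handle
};

class Wrapper {
 public:
  void AddGroupInfo(const std::string &name, GroupInfo group) {
    if (comm_init_done_) {
      InitOneGroup(name, &group);  // post-init path: set up immediately
    }
    group_info_[name] = group;     // pre-init path: just record the group
  }

  void InitNCCLComm() {
    for (auto &entry : group_info_) {
      InitOneGroup(entry.first, &entry.second);  // batch-init deferred groups
    }
    comm_init_done_ = true;
  }

 private:
  void InitOneGroup(const std::string &name, GroupInfo *group) {
    group->initialized = true;  // the real code calls ncclCommInitRank here
    std::cout << "initialized communicator for group " << name << std::endl;
  }

  bool comm_init_done_ = false;
  std::map<std::string, GroupInfo> group_info_;
};

int main() {
  Wrapper w;
  w.AddGroupInfo("nccl_world_group", {2, 0});  // recorded, not yet initialized
  w.InitNCCLComm();                            // world group initialized here
  w.AddGroupInfo("user_group", {2, 0});        // initialized immediately
  return 0;
}
```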
@@ -33,29 +33,23 @@ class NCCLWrapper {
   NCCLWrapper &operator=(const NCCLWrapper &) = delete;
   static NCCLWrapper &instance();
   ncclUniqueId nccl_unique_id() const;
-  void set_nccl_unique_id(ncclUniqueId unique_id);
-  void set_rank(int rank_id, int rank_size);
   void InitNCCLComm();
-  void InitNCCLComm(ncclComm_t *comm, int rank_size, ncclUniqueId unique_id, int rank);
   ncclResult_t AllReduce(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype,
                          ncclRedOp_t op, cudaStream_t stream, const std::string &group_name = NCCL_WORLD_GROUP);
   ncclResult_t AllGather(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype,
                          cudaStream_t stream, const std::string &group_name = NCCL_WORLD_GROUP);
   ncclResult_t ReduceScatter(const void *input_addr, void *output_addr, size_t count, ncclDataType_t datatype,
                              ncclRedOp_t op, cudaStream_t stream, const std::string &group_name = NCCL_WORLD_GROUP);
-  void SetGroupNameToNCCLComm(const std::string &group_name, const ncclComm_t comm);
+  void AddGroupInfo(const std::string &group_name, NcclGroupInfo *group);
   void DestroyGroup(const std::string &group_name);
 private:
-  NCCLWrapper() : rank_id_(-1), rank_size_(0) {}
+  NCCLWrapper() : comm_init_done_(false) {}
   ~NCCLWrapper() = default;
 private:
-  int rank_id_;
-  int rank_size_;
-  ncclUniqueId unique_id_;
-  ncclComm_t comm_;
-  std::map<std::string, ncclComm_t> group_to_comm_map_;
+  bool comm_init_done_;
+  std::map<std::string, NcclGroupInfo> group_info_;
 };
 } // namespace gpu
 } // namespace device
@@ -15,45 +15,24 @@
  */
 #include "runtime/device/gpu/mpi/mpi_initializer.h"
 #include <dlfcn.h>
 #include <mpi.h>
 #include <pybind11/operators.h>
 #include <iostream>
 #include <string>
 namespace mindspore {
 namespace device {
 namespace gpu {
-MPIInitializer::MPIInitializer() {
-  int init_flag = 0;
-  if (MPI_Initialized(&init_flag) != MPI_SUCCESS) {
-    return;
-  }
-  if (init_flag == 0) {
-    auto ret = MPI_Init(nullptr, nullptr);
-    if (ret != MPI_SUCCESS) {
-      return;
-    }
-  }
-  MPI_Comm_rank(MPI_COMM_WORLD, &rank_id_);
-  MPI_Comm_size(MPI_COMM_WORLD, &rank_size_);
-}
-MPIInitializer::~MPIInitializer() {
-  int finalized_flag = 0;
-  (void)MPI_Finalized(&finalized_flag);
-  if (finalized_flag == 0) {
-    (void)MPI_Finalize();
-  }
-}
 MPIInitializer &MPIInitializer::GetInstance() {
   static MPIInitializer instance;
   return instance;
 }
-int MPIInitializer::get_rank_id() { return MPIInitializer::GetInstance().rank_id_; }
-int MPIInitializer::get_rank_size() { return MPIInitializer::GetInstance().rank_size_; }
+int MPIInitializer::get_rank_id(const std::string &group) { return GetRankIDByGroup(group); }
+int MPIInitializer::get_rank_size(const std::string &group) { return GetGroupSize(group); }
 PYBIND11_MODULE(_ms_mpi, mpi_initializer) {
   mpi_initializer.doc() = "mindspore mpi python wrapper";
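The module definition is truncated after the docstring. A plausible continuation (an assumption, not the verbatim source) binds the two static helpers so the Python side shown further down can call mpi.get_rank_id(group) and mpi.get_rank_size(group); the free-function stubs here stand in for the MPIInitializer statics:

```cpp
#include <pybind11/pybind11.h>
#include <string>

namespace py = pybind11;

// Illustrative stubs; the real module binds MPIInitializer::get_rank_id and
// MPIInitializer::get_rank_size, which delegate to the gpu_collective library.
static int get_rank_id(const std::string &group) { return 0; }
static int get_rank_size(const std::string &group) { return 1; }

PYBIND11_MODULE(_ms_mpi, mpi_initializer) {
  mpi_initializer.doc() = "mindspore mpi python wrapper";
  mpi_initializer.def("get_rank_id", &get_rank_id, "get rank id of the given group");
  mpi_initializer.def("get_rank_size", &get_rank_size, "get rank size of the given group");
}
```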
@@ -17,6 +17,9 @@
 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_MPI_MPI_INITIALIZER_H_
 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_MPI_MPI_INITIALIZER_H_
+#include <string>
+#include "runtime/device/gpu/distribution/collective_wrapper.h"
 namespace mindspore {
 namespace device {
 namespace gpu {
@@ -25,15 +28,12 @@ class MPIInitializer {
   MPIInitializer(MPIInitializer const &) = delete;
   MPIInitializer &operator=(const MPIInitializer &) = delete;
   static MPIInitializer &GetInstance();
-  static int get_rank_id();
-  static int get_rank_size();
+  static int get_rank_id(const std::string &group);
+  static int get_rank_size(const std::string &groups);
 private:
-  MPIInitializer();
-  ~MPIInitializer();
-  int rank_id_;
-  int rank_size_;
+  MPIInitializer() = default;
+  ~MPIInitializer() = default;
 };
 } // namespace gpu
 } // namespace device
@@ -163,10 +163,7 @@ def _get_rank_helper(group, backend):
         else:
             rank_id = hccl.get_rank_id(group)
     elif backend == Backend.NCCL:
-        if group == NCCL_WORLD_COMM_GROUP:
-            rank_id = mpi.get_rank_id()
-        else:
-            raise RuntimeError("Nccl doesn't support get_rank_id by user group now.")
+        rank_id = mpi.get_rank_id(group)
     else:
         raise ValueError("Invalid backend: '{}'".format(backend))
     return rank_id
@@ -225,10 +222,7 @@ def _get_size_helper(group, backend):
         else:
             size = hccl.get_rank_size(group)
     elif backend == Backend.NCCL:
-        if group == NCCL_WORLD_COMM_GROUP:
-            size = mpi.get_rank_size()
-        else:
-            raise RuntimeError("Nccl doesn't support get_rank_size by user group now.")
+        size = mpi.get_rank_size(group)
     else:
         raise ValueError("Invalid backend: '{}'".format(backend))
     return size
@@ -22,6 +22,7 @@ equal_op_info = AkgGpuRegOp("Equal") \
     .output(0, "output") \
     .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \
     .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \
+    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \
     .get_op_info()